Are you trying to deploy a machine learning model and wondering how?
Deploying machine learning models is possible with Flask, a popular Python web framework.
In this tutorial, I will show how to deploy machine learning models using Flask.
Which Python Modules to Use For Machine Learning?
Before starting, you must install a few dependencies on your computer to train a machine learning model and use Flask to communicate with the trained model.
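Install the following dependencies:
- Python
- Flask
- scikit-learn
- pandas
The pickle module is part of Python's standard library. It comes with Python, so you don't need to install it separately.
Visit Python’s official website, download Python, and install it on your computer. To install Flask, scikit-learn, and pandas, use the following commands in your command line interface (NumPy is installed automatically as a dependency of scikit-learn):
pip install Flask
pip install scikit-learn
pip install pandas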
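To confirm the installation worked, you can ask Python for the installed version of each package. The following is a minimal sketch that uses the standard library's importlib.metadata and assumes a recent Python 3. The helper name check_dependencies is hypothetical. pickle is not checked because it ships with Python itself.

```python
from importlib.metadata import PackageNotFoundError, version


def check_dependencies(packages):
    """Map each pip distribution name to its installed version, or None if missing."""
    installed = {}
    for name in packages:
        try:
            installed[name] = version(name)
        except PackageNotFoundError:
            installed[name] = None
    return installed


if __name__ == "__main__":
    # Distribution names as used with pip (pickle is in the standard library)
    for name, ver in check_dependencies(["Flask", "scikit-learn", "pandas"]).items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```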
How to Build a Machine Learning Model Using Python
We will start by training a machine learning model on the famous Iris dataset. It will be a classification model that predicts the species of an iris flower from the length and width of its sepals and petals.
Here are the first rows of the CSV file (iris.csv):
sepal.length,sepal.width,petal.length,petal.width,species
5.1,3.5,1.4,0.2,Setosa
4.9,3,1.4,0.2,Setosa
4.7,3.2,1.3,0.2,Setosa
4.6,3.1,1.5,0.2,Setosa
5,3.6,1.4,0.2,Setosa
To train the model, we will use the scikit-learn library.
Here is the complete code (create_machine_learning_model.py) that trains the model on the dataset and saves it as a pickle file:
import pickle
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
# Reading the data
iris = pd.read_csv("iris.csv")
print(iris.head())
# Target (species) and features (the four measurements)
y = iris['species']
x = iris[['sepal.length', 'sepal.width', 'petal.length', 'petal.width']]
# Training the model on 70% of the rows; the other 30% are kept for testing
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)
model = LogisticRegression(max_iter=100)
model.fit(x_train, y_train)
# Testing the model on rows it has not seen during training
print("Test accuracy:", model.score(x_test, y_test))
# Saving the trained model to model.pkl
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)
In the code above, we do the following:
- Import the Python modules we need
- Read the data from the Iris dataset
- Train the model and store it in a file using the pickle module
We load the dataset from a CSV file into a pandas DataFrame and split it into a training set and a test set. With test_size=0.3, 30% of the rows go to the test set. The training set is used to train the model. The test set is used to check how well the model performs on data it has not seen.
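The following is a minimal sketch of scoring the model on the held-out test set. To make it run on its own, it loads the copy of the Iris dataset bundled with scikit-learn (load_iris) instead of iris.csv. The random_state and stratify arguments and max_iter=200 are assumptions that are not part of the original script. random_state makes the split reproducible, and stratify keeps the species proportions equal in both sets.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def train_and_evaluate(test_size=0.3, random_state=42):
    """Train on part of the Iris data and return (model, accuracy on the held-out part)."""
    # Bundled Iris data: 150 rows, 4 features, 3 species
    x, y = load_iris(return_X_y=True, as_frame=True)
    # stratify=y keeps class proportions equal; random_state makes the split repeatable
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=random_state, stratify=y
    )
    model = LogisticRegression(max_iter=200)
    model.fit(x_train, y_train)
    # score() returns the mean accuracy on the test rows
    return model, model.score(x_test, y_test)


if __name__ == "__main__":
    model, accuracy = train_and_evaluate()
    print(f"Test accuracy: {accuracy:.2f}")
```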
By printing iris.head() we can see the first five rows of the dataframe:
   sepal.length  sepal.width  petal.length  petal.width species
0           5.1          3.5           1.4          0.2  Setosa
1           4.9          3.0           1.4          0.2  Setosa
2           4.7          3.2           1.3          0.2  Setosa
3           4.6          3.1           1.5          0.2  Setosa
4           5.0          3.6           1.4          0.2  Setosa
We are training a logistic regression model on our dataset. Logistic regression is a statistical method that estimates the probability of a binary outcome (a variable with one of two possible values, such as true/false or yes/no). The Iris dataset contains three species. scikit-learn's LogisticRegression also handles more than two classes, so we can train it to predict the species of each iris flower.
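With more than two classes, scikit-learn's default lbfgs solver fits a multinomial logistic regression. The model computes one linear score per class and turns the scores into probabilities with the softmax function. The sketch below shows only that last step in plain Python. The scores are made-up numbers for illustration, not values taken from the trained model.

```python
import math


def softmax(scores):
    """Convert a list of per-class scores into probabilities that sum to 1."""
    # Subtracting the max score avoids overflow and does not change the result
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Hypothetical scores for the three species
    species = ["setosa", "versicolor", "virginica"]
    probabilities = softmax([2.0, 0.5, -1.0])
    for name, p in zip(species, probabilities):
        print(f"{name}: {p:.3f}")
    # The predicted class is the one with the highest probability
    print("Predicted:", species[probabilities.index(max(probabilities))])
```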
We import the Python pickle module, and at the end of the code, you can see that we save the trained model into a pickle file (model.pkl).
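The Flask app later loads this file back. Here is a minimal standard-library sketch of the full pickle round trip. It uses with blocks so the file is always closed. A small dict stands in for the trained model so the example runs without scikit-learn, and save_model and load_model are hypothetical helper names. Only unpickle files you trust: the Python documentation warns that loading pickle data can execute arbitrary code.

```python
import os
import pickle
import tempfile


def save_model(obj, path):
    """Serialize obj to path with pickle (the file must be opened in binary mode)."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_model(path):
    """Load a pickled object from path. Only use this on files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    # Stand-in for the trained model: any picklable object works the same way
    fake_model = {"classes": ["Setosa", "Versicolor", "Virginica"], "max_iter": 100}
    path = os.path.join(tempfile.mkdtemp(), "model.pkl")
    save_model(fake_model, path)
    print(load_model(path) == fake_model)
```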
We will now use a Flask API to communicate with the trained model.
How Do You Deploy a Machine Learning Model With Flask?
Now let’s see how to communicate with the machine learning model using Flask.
Here is the code for our Flask application:
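import pickle
import pandas as pd
from flask import Flask, request, render_template
app = Flask(__name__)  # Initialize the Flask app
# Load the trained model once, when the app starts
with open('model.pkl', 'rb') as f:
    model = pickle.load(f)
# Form field names (from index.html) and the matching column names used in training
FEATURES = {
    'sepal_length': 'sepal.length',
    'sepal_width': 'sepal.width',
    'petal_length': 'petal.length',
    'petal_width': 'petal.width',
}
@app.route('/')  # Homepage
def home():
    return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
    # Retrieve the values from the form by field name, so their order is explicit
    try:
        values = [float(request.form[field]) for field in FEATURES]
    except ValueError:
        return render_template('index.html', prediction_text='Please enter numeric values.')
    # One-row DataFrame with the same column names the model was trained on
    final_features = pd.DataFrame([values], columns=list(FEATURES.values()))
    prediction = model.predict(final_features)  # Make a prediction
    # Render the predicted species (the only element of the returned array)
    return render_template('index.html', prediction_text='Predicted Species: {}'.format(prediction[0]))
if __name__ == "__main__":
    app.run(debug=True)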
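In this code, we define two endpoints. The / endpoint renders the homepage. The /predict endpoint receives the values submitted through the web form, passes them to the model, and renders the predicted species on the same page.
The trained model is loaded from model.pkl once, when the application starts, and is reused for every prediction.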
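The routes above return an HTML page. If other programs also need predictions, not only browsers, a JSON endpoint built with Flask's jsonify is a natural addition. Below is a minimal sketch under these assumptions: the /api/predict route and the JSON field names are hypothetical, and StubModel is a hypothetical stand-in for the unpickled LogisticRegression so the example runs without model.pkl. With the real model, passing a one-row pandas DataFrame with the training column names avoids scikit-learn's feature-name warning.

```python
from flask import Flask, jsonify, request

app = Flask(__name__)

# Feature names in the order used for training
FEATURES = ["sepal.length", "sepal.width", "petal.length", "petal.width"]


class StubModel:
    """Hypothetical stand-in for the pickled model: looks only at petal length."""

    def predict(self, rows):
        return ["Setosa" if row[2] < 2.5 else "Not Setosa" for row in rows]


model = StubModel()  # in app.py this would be the model loaded from model.pkl


@app.route("/api/predict", methods=["POST"])
def api_predict():
    # silent=True returns None instead of raising on a missing or invalid JSON body
    data = request.get_json(silent=True) or {}
    try:
        row = [float(data[name]) for name in FEATURES]
    except (KeyError, TypeError, ValueError):
        message = "Send a JSON object with numeric values for: " + ", ".join(FEATURES)
        return jsonify(error=message), 400
    prediction = model.predict([row])[0]
    return jsonify(species=str(prediction))


if __name__ == "__main__":
    app.run(debug=True)
```

A client would POST a body such as {"sepal.length": 5.1, "sepal.width": 3.5, "petal.length": 1.4, "petal.width": 0.2} and receive a JSON object with a species key.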
Before running app.py, create a simple web UI with an input field for each value the model expects. Flask looks for templates in a folder called templates, so save the page as templates/index.html. The prediction returned by the model will be shown on the same page.
The UI is a form with four input fields and a Predict button (you can customize its look with CSS).
When you submit the form, the browser sends the values to the /predict endpoint of the Flask app with a POST request.
The model then makes a prediction, and the app renders the result back in the UI.
Here is the directory structure of the project. Use it to check that each file is in the right location:
.
├── app.py
├── create_machine_learning_model.py
├── iris.csv
├── model.pkl
└── templates
    └── index.html
1 directory, 5 files
And here is the HTML code of the web UI (templates/index.html):
<!DOCTYPE html>
<html>
  <head>
    <title>Machine Learning Model Deployment</title>
  </head>
  <body>
    <div>
      <h1>Predict Iris Species</h1>
      <form action="{{ url_for('predict') }}" method="post">
        <input type="text" name="sepal_length" placeholder="Sepal Length (cm)" required="required" />
        <input type="text" name="sepal_width" placeholder="Sepal Width (cm)" required="required" />
        <input type="text" name="petal_length" placeholder="Petal Length (cm)" required="required" />
        <input type="text" name="petal_width" placeholder="Petal Width (cm)" required="required" />
        <button type="submit">Predict</button>
      </form>
      <br>
      {{ prediction_text }}
    </div>
  </body>
</html>
Now open your command line, go to the directory where you created the project, and run the following command:
python app.py
You will see an output similar to the one below:
* Serving Flask app "app" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment.
Use a production WSGI server instead.
* Debug mode: on
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
* Restarting with fsevents reloader
* Debugger is active!
* Debugger PIN: 334-659-046
The web app will run on http://localhost:5000. Open that address in your browser, fill in the form, and click Predict to get a prediction from the trained machine learning model.
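You can also call the /predict endpoint from a script instead of the browser. The following is a minimal sketch using Python's standard library. It builds a form-encoded POST request with the same field names as the inputs in index.html. build_predict_request is a hypothetical helper. The request is only sent when you run the file while the Flask app is running locally, and the response body is the rendered HTML page.

```python
from urllib.parse import urlencode
from urllib.request import Request, urlopen


def build_predict_request(measurements, url="http://localhost:5000/predict"):
    """Build a form-encoded POST request like the one the HTML form sends."""
    body = urlencode(measurements).encode("utf-8")
    return Request(
        url,
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


if __name__ == "__main__":
    # Field names match the name attributes of the inputs in index.html
    req = build_predict_request(
        {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}
    )
    # Requires the Flask app to be running; the response is the rendered HTML page
    with urlopen(req) as resp:
        html = resp.read().decode("utf-8")
    print("Predicted Species" in html)
```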
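Warning “lbfgs failed to converge”: How Can You Fix It?
While running the Python program that generates the pickled model, you might have seen the warning “lbfgs failed to converge”.
It is a ConvergenceWarning raised by LogisticRegression from sklearn.linear_model. It appears when the default solver, lbfgs, reaches the maximum number of iterations (max_iter) before converging. The model is still trained and saved, but it may not be fully optimized.
Here is what the warning looks like: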
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
n_iter_i = _check_optimize_result(
If you see this warning, you can fix it by increasing the value of the max_iter argument you pass to the LogisticRegression classifier. As the message says, scaling the data or choosing a different solver can also help.
You can reproduce the warning by taking this line of code from the training script:
model = LogisticRegression(max_iter=100)
and changing the max_iter value to 50.
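Besides raising max_iter, the warning text suggests scaling the data. The following is a minimal sketch that compares the options. It assumes scikit-learn's bundled load_iris data instead of iris.csv. count_convergence_warnings is a hypothetical helper that fits a model and counts the ConvergenceWarning instances raised. The scaled variant standardizes each feature with StandardScaler inside a pipeline before the logistic regression. If you go this way, pickle the whole pipeline so the Flask app applies the same scaling at prediction time.

```python
import warnings

from sklearn.datasets import load_iris
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def count_convergence_warnings(model, x, y):
    """Fit model on (x, y) and return how many ConvergenceWarnings were raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, even repeated ones
        model.fit(x, y)
    return sum(issubclass(w.category, ConvergenceWarning) for w in caught)


if __name__ == "__main__":
    x, y = load_iris(return_X_y=True)
    candidates = {
        "unscaled, max_iter=50": LogisticRegression(max_iter=50),
        "unscaled, max_iter=1000": LogisticRegression(max_iter=1000),
        "scaled, max_iter=100": make_pipeline(StandardScaler(), LogisticRegression(max_iter=100)),
    }
    for name, model in candidates.items():
        print(f"{name}: {count_convergence_warnings(model, x, y)} convergence warning(s)")
```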
Conclusion
With this tutorial, you learned how to deploy a machine learning model with Flask. You have seen that it’s a simple procedure.
You can provide users with access to your machine learning model by building a Flask application, loading the trained model, specifying a prediction function, and developing an API endpoint.
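Flask makes it simple to build apps that serve predictions from a trained model. To handle production traffic, run the app with a production WSGI server instead of the built-in development server, as the startup warning above recommends.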
Related article: in this tutorial, we have used the pickle module. Read the CodeFatherTech article that covers Python’s pickle module more in-depth.
Claudio Sabato is an IT expert with over 15 years of professional experience in Python programming, Linux Systems Administration, Bash programming, and IT Systems Design. He is a professional certified by the Linux Professional Institute.
With a Master’s degree in Computer Science, he has a strong foundation in Software Engineering and a passion for robotics with Raspberry Pi.