用Flask在服务器上部署训练好的模型

2,899 阅读4分钟

本教程的三个部分是:

1、先训练好一个模型

2、构建Flask微框架所需的组件以创建Web app

3、运行web app

需要安装的组件

如果要从此页面复制和粘贴代码,请确保已安装以下组件:

Python 3.6+
python packages:
Flask
Pandas
Sklearn
Xgboost
Seaborn
Matplotlib

1、先训练好一个模型

模型这里我就不阐述了, 随便什么模型都可以,不管是用tensorflow还是用pytorch写的,也都行。

2、构建Flask微框架所需的组件以创建Web app

我们需要做一些事情将Web app整合在一起:

a、Python代码 包括载入我们训练好的模型,从Web表单获取用户输入,进行预测并返回结果

b、HTML模板 允许用户输入自己的数据并显示结果

该web app的初始结构如下:

image.png

首先我将创建一个非常基本的app.py和main.html,以演示flask如何工作。我们将在后面扩展程序以适应我们的需求

app.py

这是web app的核心。它将在服务器上运行,发送网页并处理用户的输入

import flask
app = flask.Flask(__name__, template_folder='templates')
@app.route('/')
def main():
    return(flask.render_template('main.html'))
if __name__ == '__main__':
    app.run(host = '0.0.0.0')

main.html

这就是前端的界面。它现在所做的只是显示一条简单的消息,我们稍后将对其进行编辑以适合我们的需求

<!doctype html>
<html>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<head>
<title>Web app name</title>
</head>
<h1>Hello world.</h1>
</html>

Running the test app

要在服务器上先启动flask,打开终端,请确保您位于该webapp文件夹中,然后运行以下命令: flask run

image.png 然后按CTRL+C退出

再运行app.python文件 python app.python

image.png

然后在本地电脑的浏览器中输入(127.0.0.1:5000),由于我们是在服务器上部署的,所以应该把127.0.0.1替换成访问服务器的域名 http://202.116.46.215:5000/ ,假如还是不行,有可能是5000端口没有开

image.png

到这里我们的一个基本web app已经完成了,接下来我们将需要再修改程序,使其满足我们的需求。

我用我写的模型进行修改,你们只需要按照你们的模型进行相应的修改就行了 我训练好的是一个食物识别的模型,用tensorflow写的,我把训练好的模型保存到model/下

image.png

然后加在我的模型,加载模型的相关代码是

model = 'R50+ViT-B_16'
VisionTransformer = models.KNOWN_MODELS[model].partial(num_classes=172)
params = checkpoint.load(f'/model/best_model.npz')
params['pre_logits'] = {}  # Need to restore empty leaf for Flax.

如果你用的pytorch写的模型,加载模型的代码可以参考

# Use pickle to load in the pre-trained model.
with open(f'./model/bike_model_xgboost.pkl', 'rb') as f:
    model = pickle.load(f)

模型加载好了,我做的是图片识别,我要从前端获取上传的照片,然后做预测再返回

先把照片保存到文件夹下来,我在前端main.html文件中是使用标签<input id="file" name = "file" type= "file"/>来上传照片的

#当前绝对路径
basedir = os.path.abspath(os.path.dirname(__file__))
f = request.files.get('file')
# 获取安全的文件名 正常的文件名
filename = secure_filename(f.filename)


# f.filename.rsplit('.', 1)[1] 获取文件的后缀
# 把文件重命名
filename = datetime.now().strftime("%Y%m%d%H%M%S") + "." + "JPG"
print(filename)
# 保存的目标绝对地址
file_path = basedir + "/images/"
# 保存文件到目标文件夹
f.save(file_path + filename)

获取照片,做预测,返回结果

img = PIL.Image.open('./images/' + filename)
img = img.resize((384,384))
logits, = VisionTransformer.call(params, (np.array(img) / 128 - 1)[None, ...])
#后面就是做softmax,得到概率最大值的结果,然后返回预测结果
preds = flax.nn.softmax(logits)
labels = dict(enumerate(open('labels.txt'),start=1))
for idx in preds.argsort()[:-11:-1]:
    print(f'{preds[idx]:.5f} : {labels[idx+1]}', end='')
    predict = labels[idx+1]
    break
predict = predict[1:-1]
print(predict)
return flask.render_template('main.html', result = predict,)
 

整个app.py的代码

import flask
import pickle
import pandas as pd
from datetime import datetime
from flask import Flask, request, jsonify
import os
from werkzeug.utils import secure_filename
from vit_jax import models
from vit_jax import checkpoint
import flax
import PIL
import numpy as np

#当前绝对路径
basedir = os.path.abspath(os.path.dirname(__file__))

# Initialise the Flask app
app = flask.Flask(__name__, template_folder='templates')

# Set up the main route
@app.route('/', methods=['GET', 'POST'])
def main():
    if flask.request.method == 'GET':
        # Just render the initial form, to get input
        return(flask.render_template('main.html'))
    
    if flask.request.method == 'POST':
        f = request.files.get('file')
        # 获取安全的文件名 正常的文件名
        filename = secure_filename(f.filename)
        
        # f.filename.rsplit('.', 1)[1] 获取文件的后缀
        # 把文件重命名
        filename = datetime.now().strftime("%Y%m%d%H%M%S") + "." + "JPG"
        print(filename)
        # 保存的目标绝对地址
        file_path = basedir + "/images/"
        # 保存文件到目标文件夹
        f.save(file_path + filename)
        
        #加载模型
        model = 'R50+ViT-B_16'
        VisionTransformer = models.KNOWN_MODELS[model].partial(num_classes=172)
        params = checkpoint.load(f'./model/best_model.npz')
        params['pre_logits'] = {}  # Need to restore empty leaf for Flax.
 
        #读取图片,做预测,返回结果
        img = PIL.Image.open('./images/' + filename)
        img = img.resize((384,384))
        logits, = VisionTransformer.call(params, (np.array(img) / 128 - 1)[None, ...])
        labels = dict(enumerate(open('labels.txt'),start=1))
        preds = flax.nn.softmax(logits)
        for idx in preds.argsort()[:-11:-1]:
            print(f'{preds[idx]:.5f} : {labels[idx+1]}', end='')
            predict = labels[idx+1]
            break
        predict = predict[1:-1]
        return flask.render_template('main.html', result = predict,)
 
if __name__ == '__main__':
    app.run(host = '0.0.0.0')
    
    

app.py前后端交互就写好了,现在开始修改前端界面

由于我要的输入是一张图片,所以我要上传图片,如果你要的输入是文本,你就用文本框就行了

<!doctype html>
<html>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<style>
form {
    margin: auto;
    width: 35%;
}

.result {
    margin: auto;
    width: 35%;
    border: 1px solid #ccc;
}
</style>

<head>
    <title>Food Recognition Model</title>
</head>
<form action="{{ url_for('main') }}" method="POST" enctype = "multipart/form-data">
    <fieldset>
        <legend>Input values:</legend>
        <label for = "file">文件名:</label>
        <input id="file" name = "file" type= "file"/>
        <input type="submit" name="submit" value="提交"/>
    </fieldset>
</form>
<br>
<div class="result" align="center">
    {% if result %}
      <br> The food is:
      <p style="font-size:50px">{{ result }}</p>
    {% endif %}
</div>
</html>

到这里,整个web app我们就完成了,现在赶快试一试吧

3、运行web app

在终端输入python app.py

然后再浏览器输入(换成自己的域名) http://202.116.46.215:5000/

image.png

d8e7540cff30865155590b17867f9029 (1).gif