代码
from flask import Flask, request, jsonify, send_from_directory
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from flask_cors import CORS
import os
# 定义模型结构
class AlexNet(nn.Module):
def __init__(self, num_classes=1):
super(AlexNet, self).__init__()
self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, bias=False)
self.relu_pool1 = bn_relu_pool(inplanes=96)
self.conv2 = nn.Conv2d(96, 192, kernel_size=5, padding=2, groups=2, bias=False)
self.relu_pool2 = bn_relu_pool(inplanes=192)
self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1, groups=2, bias=False)
self.relu3 = bn_relu(inplanes=384)
self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False)
self.relu4 = bn_relu(inplanes=384)
self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False)
self.relu_pool5 = bn_relu_pool(inplanes=256)
# classifier
self.conv6 = nn.Conv2d(256, 256, kernel_size=5, groups=2, bias=False)
self.relu6 = bn_relu(inplanes=256)
self.conv7 = nn.Conv2d(256, num_classes, kernel_size=1, bias=False)
def forward(self, x):
# Define the forward pass of your model
x = self.conv1(x)
x = self.relu_pool1(x)
x = self.conv2(x)
x = self.relu_pool2(x)
x = self.conv3(x)
x = self.relu3(x)
x = self.conv4(x)
x = self.relu4(x)
x = self.conv5(x)
x = self.relu_pool5(x)
x = self.conv6(x)
x = self.relu6(x)
x = self.conv7(x)
x = x.view(x.size(0), -1)
return x
def bn_relu(inplanes):
return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True))
def bn_relu_pool(inplanes, kernel_size=3, stride=2):
return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=kernel_size, stride=stride))
# 加载预训练模型的权重到当前模型实例中
def load_model(model_path, model):
print("Loading model...")
pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'), encoding='latin1')
model_dict = model.state_dict()
# Filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict['state_dict'].items() if k in model_dict}
# Overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print("Model loaded successfully!")
# 预处理
def preprocess_image(image):
print("Preprocessing image...")
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return transform(image)
# 初始化Flask应用
app = Flask(__name__)
CORS(app)
# 定义接受图片路由和返回预测路由
@app.route('/')
def serve_index():
return send_from_directory('frontend', 'index.html')
@app.route('/predict', methods=['POST'])
def predict():
print("Received POST request...")
if 'file' not in request.files:
return jsonify({'error': 'No file part'})
file = request.files['file']
if file.filename == '':
return jsonify({'error': 'No selected file'})
if file:
try:
# 加载模型
model = AlexNet().cuda()
load_model('facial_beauty/models/alexnet.pth', model)
model.eval()
# 预处理客户端传送过来图片
image = Image.open(file).convert('RGB')
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0).cuda()
# 预测
with torch.no_grad():
print("Making predictions...")
output = model(image_tensor)
print("Predictions made successfully!")
prediction = output.item()
return jsonify({'prediction': prediction})
except Exception as e:
error_message = str(e).encode('utf-8') # Encode error message as UTF-8
return jsonify({'error': error_message.decode('utf-8')}) # Decode error message for JSON serialization
if __name__ == '__main__':
app.run(debug=True)
笔记
-
flask是 Flask 框架的模块名;Flask是 Flask 模块中的一个类。flask_cors是 Flask 的一个扩展,用于简化设置 CORS 规则。通过
CORS(app)这行代码,可以自动为所有路由启用 CORS 支持,允许跨域请求。
#app是Flask类的一个实例
app = Flask(__name__)
CORS(app)
@app.route('/')
def serve_index():
return send_from_directory('frontend', 'index.html')
@app.route('/predict', methods=['POST'])
def predict():
...
send_from_directory是 Flask 提供的一个方法,用于从指定的目录中发送文件给客户端。它通常用于提供静态文件,如 HTML 文件、图像、CSS、JavaScript 文件等。send_from_directory('downloads', 'filename')``downloads是文件所在的目录,filenname表示文件名。send_from_directory()是一个安全函数,确保只能从指定目录发送文件,避免目录遍历攻击。如果需要提供大量静态文件,最好使用 Flask 的static_folder配置选项,或者使用专门的静态文件服务器,如 Nginx 或 Apache。@app.route('/')是 Flask 的路由装饰器,用于将 URL 路径映射到视图函数。当客户端访问特定 URL 时,Flask 会调用写在@route()装饰器下面的视图函数。URL 路径可以包含动态部分,用尖括号表示(见下面示例)。- serve_index()函数作为视图函数,通过路由装饰器
@app.route('/')调用,当用户访问根路径/时,Flask 会调用serve_index。根目录是指 URL 路径的起始点,在大多数情况下,是Web 应用程序的首页。
@app.route('/user/<username>') def show_user_profile(username): return f'User {username}'
route()常见用法:
-
使用
request处理表单数据。 -
使用
jsonify返回 JSON 响应。 -
启用了 Flask 的调试模式(
debug=True),方便在开发过程中查看详细的错误信息。
@app.route('/submit', methods=['POST'])
def submit():
data = request.form['data']
return f'Received data: {data}'
@app.route('/api/data')
def get_data():
data = {"name": "Alice", "age": 30}
return jsonify(data)
if __name__ == '__main__': app.run(debug=True)
route()常用参数:
rule:URL 路径。可以包含动态部分(变量规则)。methods:指定允许的 HTTP 方法,默认是['GET']。常见方法包括GET、POST、PUT、DELETE等。endpoint:视图函数的别名,默认是视图函数的名称。strict_slashes:是否严格区分 URL 末尾的斜杠,默认为True。
- PIL 是 Python Imaging Library 的缩写,是一个 Python 图像处理库,用于打开、操作和保存多种格式的图像文件。
convert()是 PIL 中的一个方法,用于转换图像的颜色模式。比如convert('RGB')将图像转换为 RGB 颜色模式,即红、绿、蓝三通道。
-
__init__()是 Python 类的构造函数,初始化对象时会自动调用。定义了模型的结构。比如:__init__中定义了conv1...conv7、relu_pool1...relu_pool5、relu1...relu6等结构,forward中可以直接self.conv1调用。 -
nn.Sequential是 PyTorch 中的一个容器模块,可以将多个层串联在一起。比如:
nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True))
包含BatchNorm2d 和 ReLU 层。
学会将多个层串联在一起简化代码:
class AlexNet(nn.Module):
def __init__(self, num_classes=1):
super(AlexNet, self).__init__()
self.relu_pool1 = bn_relu_pool(inplanes=96)
self.relu_pool5 = bn_relu_pool(inplanes=256)
...
def bn_relu(inplanes):
return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True))
def bn_relu_pool(inplanes, kernel_size=3, stride=2):
return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=kernel_size, stride=stride))
map_location=torch.device('cpu'):是为了确保无论模型是在 GPU 上训练还是保存的,都可以将其加载到 CPU 上使用。cuda()是 PyTorch 的一个方法,用于将模型或张量移动到 GPU 设备上。比如:
model = AlexNet().cuda()
encoding='latin1':指定文件的编码方式为latin1,用于处理一些特殊字符的情况。其他比如:utf-8、ASCII、GBK...state_dict()是 PyTorch 中的一个方法,是一种用于获取模型状态字典的方法。模型的状态字典是一个 Python 字典对象,它将每一层的参数(权重和偏置)映射到对应的 Tensor。
-
字典推导式:
{key: value for key, value in iterable if condition},生成一个只包含满足条件的键值对的新字典。结构解析:
-
{key: value}:生成的新字典中的键值对。 -
for key, value in iterable:迭代一个可迭代对象,该可迭代对象应当返回键值对。 -
if condition:一个可选的条件表达式,只有满足条件的键值对才会被包括在生成的字典中。示例:
-
带条件的字典推导式:
even_squares = {x: x**2 for x in range(1, 11) if x % 2 == 0}
print(even_squares)
# 输出: {2: 4, 4: 16, 6: 36, 8: 64, 10: 100}
- 从两个列表创建字典:
keys = ['a', 'b', 'c']
values = [1, 2, 3]
combined_dict = {k: v for k, v in zip(keys, values)}
print(combined_dict)
# 输出: {'a': 1, 'b': 2, 'c': 3}
- 基于已有字典生成新字典:
original_dict = {'a': 1, 'b': 2, 'c': 3}
filtered_dict = {k: v for k, v in original_dict.items() if v > 1}
print(filtered_dict)
# 输出: {'b': 2, 'c': 3}
-
pretrained_dict['state_dict'].items()返回pretrained_dict中state_dict的所有键值对。 -
update()是 Python 字典的一个方法,用于更新字典,将另一个字典中的键值对合并到当前字典中。 -
用另一个字典更新:
dict1 = {'a': 1, 'b': 2}
dict2 = {'b': 3, 'c': 4}
dict1.update(dict2)
print(dict1) # 输出: {'a': 1, 'b': 3, 'c': 4}
- 用键值对的列表更新:
dict1 = {'a': 1, 'b': 2}
dict1.update([('b', 3), ('c', 4)])
print(dict1) # 输出: {'a': 1, 'b': 3, 'c': 4}
- 用关键字参数更新:
dict1 = {'a': 1, 'b': 2}
dict1.update(b=3, c=4)
print(dict1) # 输出: {'a': 1, 'b': 3, 'c': 4}
load_state_dict()是 PyTorch 模型的一个方法,用于将一个 state_dict 加载到模型中,替换模型当前的参数。
def load_model(model_path, model):
print("Loading model...")
pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'), encoding='latin1')
model_dict = model.state_dict()
# 过滤掉预训练模型中那些在当前模型中不存在的参数,避免因为预训练模型和当前模型结构不完全匹配而导致的错误
pretrained_dict = {k: v for k, v in pretrained_dict['state_dict'].items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print("Model loaded successfully!")
methods:指定允许的 HTTP 方法,如GET、POST、PUT、DELETE等。
-
GET:用于从服务器获取资源,可以理解为 "读取" 操作;GET 请求是幂等的,即多次请求相同的资源返回的结果应该是一致的,不应该对资源状态有任何影响;适用于获取页面、图像、文件等资源,获取某个资源的详细信息或列表,向服务器查询信息。
-
POST:提交数据给服务器处理,通常用于创建新资源;POST 请求在 HTTP 主体中传递数据,而不是像 GET 请求那样在 URL 中传递;POST 请求不是幂等的,即多次请求可能会导致不同的结果或影响资源状态;适用于提交表单数据,发布博客、文章等内容,进行用户登录、注册等操作,传输敏感信息。
-
PUT:更新服务器上的资源;通常将整个请求体作为更新后的资源,而不是部分更新;是幂等的,多次请求相同的资源状态应该保持一致;适用于更新特定资源的全部信息,提交文件或大块数据。
-
DELETE:请求服务器删除指定的资源;请求是幂等的,即多次请求删除同一资源的结果应该是相同的;适用于删除指定的资源,如删除文件、取消订单、删除用户等。
HTTP主题传递数据与URL中传递数据:
-
URL 中传递数据:数据附加在 URL 的查询字符串部分,比如:
http://example.com/api/resource?key1=value1&key2=value2;URL 的长度有限制,适合传递少量数据;数据暴露在 URL 中,容易被记录在浏览器历史、日志文件和书签中,传递敏感信息时安全性较低;适用于 GET 请求。 -
HTTP 主体中传递数据:数据包含在 HTTP 请求的主体部分,适用于 POST、PUT 等请求方法,比如:
POST /api/resource HTTP/1.1,并在请求体中包含{ "key1": "value1", "key2": "value2" };请求体的数据量可以很大,适合传递大量数据或文件;数据在请求体中,不会暴露在 URL 中,安全性相对较高;适用于 POST、PUT 等请求。
if 'file' not in request.files:
return jsonify({'error': 'No file part'})
file = request.files['file']
if file.filename == '':
return jsonify({'error': 'No selected file'})
-
request.files:Flask 提供的一个对象,包含上传的文件。 -
jsonify:Flask 提供的一个方法,用于创建 JSON 响应。
- Python 的异常处理结构,用于捕获和处理运行时错误。
try:
# 可能会引发异常的代码
except Exception as e:
# 处理异常的代码
finally:
# 无论是否发生异常都会执行的代码
try:
...
except Exception as e:
error_message = str(e).encode('utf-8')
return jsonify({'error': error_message.decode('utf-8')})
-
Exception是 Python 内建的所有异常的基类。 -
e是捕获的异常实例。 -
encode()是字符串方法,将字符串编码为指定编码的字节序列。 -
decode()是字节序列方法,将字节序列解码为指定编码的字符串。 -
Python 的
str类型是 Unicode 字符串,在处理文件、网络数据等时,需要根据具体需求选择合适的编码方式进行转换。 -
JSON 规范允许使用多种 Unicode 编码,但JSON 的默认和最常用编码是 UTF-8。
if __name__ == '__main__': app.run(debug=True)
-
检查当前模块是否是主程序,只有在模块作为主程序执行时,才会运行
app.run()。 -
debug=True启用调试模式,提供详细的错误信息,并在代码修改时自动重启服务器。