基于SCUT-FBP5500 数据集的Facial Beauty prediction

90 阅读10分钟

代码

from flask import Flask, request, jsonify, send_from_directory
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from flask_cors import CORS
import os

# 定义模型结构
class AlexNet(nn.Module):
    def __init__(self, num_classes=1):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, bias=False)
        self.relu_pool1 = bn_relu_pool(inplanes=96)
        self.conv2 = nn.Conv2d(96, 192, kernel_size=5, padding=2, groups=2, bias=False)
        self.relu_pool2 = bn_relu_pool(inplanes=192)
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1, groups=2, bias=False)
        self.relu3 = bn_relu(inplanes=384)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False)
        self.relu4 = bn_relu(inplanes=384)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False)
        self.relu_pool5 = bn_relu_pool(inplanes=256)
        # classifier
        self.conv6 = nn.Conv2d(256, 256, kernel_size=5, groups=2, bias=False)
        self.relu6 = bn_relu(inplanes=256)
        self.conv7 = nn.Conv2d(256, num_classes, kernel_size=1, bias=False)
        
    def forward(self, x):
        # Define the forward pass of your model
        x = self.conv1(x)
        x = self.relu_pool1(x)
        x = self.conv2(x)
        x = self.relu_pool2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.conv5(x)
        x = self.relu_pool5(x)
        x = self.conv6(x)
        x = self.relu6(x)
        x = self.conv7(x)
        x = x.view(x.size(0), -1)
        return x

def bn_relu(inplanes):
    return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True))

def bn_relu_pool(inplanes, kernel_size=3, stride=2):
    return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=kernel_size, stride=stride))

# 加载预训练模型的权重到当前模型实例中
def load_model(model_path, model):
    print("Loading model...")
    pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'), encoding='latin1')
    model_dict = model.state_dict()
    # Filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict['state_dict'].items() if k in model_dict}
    # Overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print("Model loaded successfully!")

# 预处理
def preprocess_image(image):
    print("Preprocessing image...")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image)

# 初始化Flask应用
app = Flask(__name__)
CORS(app)
# 定义接受图片路由和返回预测路由
@app.route('/')
def serve_index():
    return send_from_directory('frontend', 'index.html')

@app.route('/predict', methods=['POST'])
def predict():
    print("Received POST request...")
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'})
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'})
    if file:
        try:
            # 加载模型
            model = AlexNet().cuda()
            load_model('facial_beauty/models/alexnet.pth', model)
            model.eval()  
            # 预处理客户端传送过来图片
            image = Image.open(file).convert('RGB')
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            image_tensor = transform(image).unsqueeze(0).cuda()
            # 预测
            with torch.no_grad():
                print("Making predictions...")
                output = model(image_tensor)
                print("Predictions made successfully!")
            prediction = output.item()
            return jsonify({'prediction': prediction})
        except Exception as e:
            error_message = str(e).encode('utf-8')  # Encode error message as UTF-8
            return jsonify({'error': error_message.decode('utf-8')})  # Decode error message for JSON serialization


if __name__ == '__main__':
    app.run(debug=True)

笔记

  1. flask 是 Flask 框架的模块名; Flask 是 Flask 模块中的一个类。flask_cors是 Flask 的一个扩展,用于简化设置 CORS 规则。

    通过 CORS(app) 这行代码,可以自动为所有路由启用 CORS 支持,允许跨域请求。

#app是Flask类的一个实例
app = Flask(__name__)
CORS(app)

@app.route('/')
def serve_index():
    return send_from_directory('frontend', 'index.html')
@app.route('/predict', methods=['POST'])
def predict():
    ...
  • send_from_directory 是 Flask 提供的一个方法,用于从指定的目录中发送文件给客户端。它通常用于提供静态文件,如 HTML 文件、图像、CSS、JavaScript 文件等。send_from_directory('downloads', 'filename')``downloads 是文件所在的目录,filenname表示文件名。send_from_directory() 是一个安全函数,确保只能从指定目录发送文件,避免目录遍历攻击。如果需要提供大量静态文件,最好使用 Flask 的 static_folder 配置选项,或者使用专门的静态文件服务器,如 Nginx 或 Apache。
  • @app.route('/') 是 Flask 的路由装饰器,用于将 URL 路径映射到视图函数。当客户端访问特定 URL 时,Flask 会调用写在 @route() 装饰器下面的视图函数。URL 路径可以包含动态部分,用尖括号表示(见下面示例)。
  • serve_index()函数作为视图函数,通过路由装饰器 @app.route('/') 调用,当用户访问根路径 / 时,Flask 会调用 serve_index根目录是指 URL 路径的起始点,在大多数情况下,是Web 应用程序的首页。
@app.route('/user/<username>') def show_user_profile(username): return f'User {username}'

route()常见用法:

  • 使用 request 处理表单数据。

  • 使用 jsonify 返回 JSON 响应。

  • 启用了 Flask 的调试模式(debug=True),方便在开发过程中查看详细的错误信息。

@app.route('/submit', methods=['POST']) 
def submit(): 
    data = request.form['data'] 
    return f'Received data: {data}' 

@app.route('/api/data') 
def get_data(): 
    data = {"name": "Alice", "age": 30} 
    return jsonify(data) 

if __name__ == '__main__': app.run(debug=True)

route()常用参数:

  • rule:URL 路径。可以包含动态部分(变量规则)。
  • methods:指定允许的 HTTP 方法,默认是 ['GET']。常见方法包括 GETPOSTPUTDELETE 等。
  • endpoint:视图函数的别名,默认是视图函数的名称。
  • strict_slashes:是否严格区分 URL 末尾的斜杠,默认为 True
  1. PIL 是 Python Imaging Library 的缩写,是一个 Python 图像处理库,用于打开、操作和保存多种格式的图像文件。
  • convert() 是 PIL 中的一个方法,用于转换图像的颜色模式。比如convert('RGB') 将图像转换为 RGB 颜色模式,即红、绿、蓝三通道。
  1. __init__() 是 Python 类的构造函数,初始化对象时会自动调用。定义了模型的结构。比如:__init__中定义了conv1...conv7、relu_pool1...relu_pool5、relu1...relu6等结构,forward中可以直接self.conv1调用。

  2. nn.Sequential 是 PyTorch 中的一个容器模块,可以将多个层串联在一起。比如:

nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True))

包含BatchNorm2dReLU 层。

学会将多个层串联在一起简化代码:

class AlexNet(nn.Module):
    def __init__(self, num_classes=1):
        super(AlexNet, self).__init__()
        self.relu_pool1 = bn_relu_pool(inplanes=96)
        self.relu_pool5 = bn_relu_pool(inplanes=256)
        ...
        
        
def bn_relu(inplanes):
    return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True))

def bn_relu_pool(inplanes, kernel_size=3, stride=2):
    return nn.Sequential(nn.BatchNorm2d(inplanes), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=kernel_size, stride=stride))
  • map_location=torch.device('cpu'):是为了确保无论模型是在 GPU 上训练还是保存的,都可以将其加载到 CPU 上使用
  • cuda() 是 PyTorch 的一个方法,用于将模型或张量移动到 GPU 设备上。比如:
model = AlexNet().cuda()
  • encoding='latin1':指定文件的编码方式为 latin1,用于处理一些特殊字符的情况。其他比如:utf-8、ASCII、GBK...
  • state_dict() 是 PyTorch 中的一个方法,是一种用于获取模型状态字典的方法。模型的状态字典是一个 Python 字典对象,它将每一层的参数(权重和偏置)映射到对应的 Tensor。
  1. 字典推导式:{key: value for key, value in iterable if condition},生成一个只包含满足条件的键值对的新字典。

    结构解析:

  • {key: value}:生成的新字典中的键值对。

  • for key, value in iterable:迭代一个可迭代对象,该可迭代对象应当返回键值对。

  • if condition:一个可选的条件表达式,只有满足条件的键值对才会被包括在生成的字典中。

    示例:

  • 带条件的字典推导式:

even_squares = {x: x**2 for x in range(1, 11) if x % 2 == 0}
print(even_squares)
# 输出: {2: 4, 4: 16, 6: 36, 8: 64, 10: 100}
  • 从两个列表创建字典:
keys = ['a', 'b', 'c']
values = [1, 2, 3]
combined_dict = {k: v for k, v in zip(keys, values)}
print(combined_dict)
# 输出: {'a': 1, 'b': 2, 'c': 3}
  • 基于已有字典生成新字典:
original_dict = {'a': 1, 'b': 2, 'c': 3}
filtered_dict = {k: v for k, v in original_dict.items() if v > 1}
print(filtered_dict)
# 输出: {'b': 2, 'c': 3}
  • pretrained_dict['state_dict'].items() 返回 pretrained_dictstate_dict 的所有键值对。

  • update() 是 Python 字典的一个方法,用于更新字典,将另一个字典中的键值对合并到当前字典中。

  • 用另一个字典更新:

dict1 = {'a': 1, 'b': 2}
dict2 = {'b': 3, 'c': 4}
dict1.update(dict2)
print(dict1)  # 输出: {'a': 1, 'b': 3, 'c': 4}
  • 用键值对的列表更新:
dict1 = {'a': 1, 'b': 2}
dict1.update([('b', 3), ('c', 4)])
print(dict1)  # 输出: {'a': 1, 'b': 3, 'c': 4}
  • 用关键字参数更新:
dict1 = {'a': 1, 'b': 2}
dict1.update(b=3, c=4)
print(dict1)  # 输出: {'a': 1, 'b': 3, 'c': 4}
  • load_state_dict() 是 PyTorch 模型的一个方法,用于将一个 state_dict 加载到模型中,替换模型当前的参数
def load_model(model_path, model):
    print("Loading model...")
    pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'), encoding='latin1')
    model_dict = model.state_dict()
    # 过滤掉预训练模型中那些在当前模型中不存在的参数,避免因为预训练模型和当前模型结构不完全匹配而导致的错误
    pretrained_dict = {k: v for k, v in pretrained_dict['state_dict'].items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print("Model loaded successfully!")
  1. methods:指定允许的 HTTP 方法,如 GETPOSTPUTDELETE 等。
  • GET:用于从服务器获取资源,可以理解为 "读取" 操作;GET 请求是幂等的,即多次请求相同的资源返回的结果应该是一致的,不应该对资源状态有任何影响;适用于获取页面、图像、文件等资源,获取某个资源的详细信息或列表,向服务器查询信息。

  • POST:提交数据给服务器处理,通常用于创建新资源;POST 请求在 HTTP 主体中传递数据,而不是像 GET 请求那样在 URL 中传递;POST 请求不是幂等的,即多次请求可能会导致不同的结果或影响资源状态;适用于提交表单数据,发布博客、文章等内容,进行用户登录、注册等操作,传输敏感信息。

  • PUT:更新服务器上的资源;通常将整个请求体作为更新后的资源,而不是部分更新;是幂等的,多次请求相同的资源状态应该保持一致;适用于更新特定资源的全部信息,提交文件或大块数据

  • DELETE:请求服务器删除指定的资源;请求是幂等的,即多次请求删除同一资源的结果应该是相同的;适用于删除指定的资源,如删除文件、取消订单、删除用户等。

    HTTP主题传递数据与URL中传递数据:

  • URL 中传递数据:数据附加在 URL 的查询字符串部分,比如:http://example.com/api/resource?key1=value1&key2=value2;URL 的长度有限制,适合传递少量数据;数据暴露在 URL 中,容易被记录在浏览器历史、日志文件和书签中,传递敏感信息时安全性较低;适用于 GET 请求。

  • HTTP 主体中传递数据:数据包含在 HTTP 请求的主体部分,适用于 POST、PUT 等请求方法,比如:POST /api/resource HTTP/1.1,并在请求体中包含 { "key1": "value1", "key2": "value2" };请求体的数据量可以很大,适合传递大量数据或文件;数据在请求体中,不会暴露在 URL 中,安全性相对较高;适用于 POST、PUT 等请求。

if 'file' not in request.files:
    return jsonify({'error': 'No file part'})
file = request.files['file']
if file.filename == '':
    return jsonify({'error': 'No selected file'})
  • request.files:Flask 提供的一个对象,包含上传的文件

  • jsonify:Flask 提供的一个方法,用于创建 JSON 响应。

  1. Python 的异常处理结构,用于捕获和处理运行时错误。
try:
    # 可能会引发异常的代码
except Exception as e:
    # 处理异常的代码
finally:
    # 无论是否发生异常都会执行的代码

try:
    ...
    
except Exception as e:
    error_message = str(e).encode('utf-8')  
    return jsonify({'error': error_message.decode('utf-8')})
  • Exception 是 Python 内建的所有异常的基类

  • e 是捕获的异常实例。

  • encode() 是字符串方法,将字符串编码为指定编码的字节序列。

  • decode() 是字节序列方法,将字节序列解码为指定编码的字符串。

  • Python 的 str 类型是 Unicode 字符串,在处理文件、网络数据等时,需要根据具体需求选择合适的编码方式进行转换。

  • JSON 规范允许使用多种 Unicode 编码,但JSON 的默认和最常用编码是 UTF-8。

  1. if __name__ == '__main__': app.run(debug=True)
  • 检查当前模块是否是主程序,只有在模块作为主程序执行时,才会运行 app.run()

  • debug=True 启用调试模式,提供详细的错误信息,并在代码修改时自动重启服务器。