基于Alexnet深度学习神经网络的人脸识别算法matlab仿真

111 阅读5分钟

1.算法理论概述

        人脸识别是计算机视觉领域中一个重要的研究方向,其目的是识别不同人的面部特征以实现自动身份识别。随着深度学习神经网络的发展,基于深度学习神经网络的人脸识别算法已经成为了当前最先进的人脸识别技术之一。本文将详细介绍基于AlexNet深度学习神经网络的人脸识别算法的实现步骤和数学公式。

 

1.1数据预处理

 

       在进行人脸识别之前,需要进行数据预处理,将原始的人脸图像转换为可以被深度学习神经网络处理的格式。数据预处理的步骤包括图像裁剪、大小归一化、灰度化和像素值标准化等。其中,图像裁剪是指将原始图像中的人脸部分裁剪出来,大小归一化是指将裁剪后的人脸图像大小调整为固定大小,灰度化是指将彩色图像转换为灰度图像,像素值标准化是指将灰度图像的像素值进行归一化处理,以便于神经网络学习。

 

1.2神经网络架构

 

         采用AlexNet深度学习神经网络进行人脸识别。AlexNet是一个经典的卷积神经网络,由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年提出。其由5个卷积层、3个全连接层和最终的分类器层组成。AlexNet的架构如下所示:

 

         其中,输入层接收大小为227×227×3的人脸图像,第一个卷积层提取96个特征图,每个特征图大小为55×55,步长为4,对应的卷积核大小为11×11×3。第二个卷积层提取256个特征图,每个特征图大小为27×27,步长为1,对应的卷积核大小为5×5×48。第三个、第四个和第五个卷积层分别提取384个、384个和256个特征图,每层特征图大小和步长与第二个卷积层相同。最后,全连接层和分类器层对提取的特征进行分类。

 

1.3损失函数

       本文采用softmax交叉熵损失函数进行训练。softmax交叉熵损失函数的数学公式如下所示:

245ce550a297c47367f3e0de11865d73_82780907_202307301312490381590489_Expires=1690694569&Signature=tm%2BDT%2Bx4qgThT9kN3AHqPzCGGlA%3D&domain=8.png  

       其中,NN表示样本数量,MM表示类别数量,yijy_{ij}表示第ii个样本的真实标签,y^ij\hat{y}_{ij}表示第ii个样本在第jj个类别上的预测概率。

 

1.4训练过程

 

       采用随机梯度下降法进行训练。具体来说,每次从训练集中随机选择一个batch的样本,将其输入神经网络中进行前向传播,得到每个类别的预测概率。然后,根据预测结果和真实标签计算损失函数,并利用反向传播算法计算每个参数的梯度。最后,根据梯度更新参数,并重复以上步骤直到达到指定的训练轮数或者达到收敛条件。

 

1.5测试过程

       在测试过程中,将测试集中的每个样本输入训练好的神经网络中,得到每个类别的预测概率。然后,根据预测概率选择概率最大的类别作为该样本的预测标签。最后,将预测标签和真实标签进行比对,计算准确率、召回率、F1值等评价指标。

 

 

 

2.算法运行软件版本

MATLAB2021a

 

3.     算法运行效果图预览

 

2.png

3.png  

4.部分核心程序 `trainingOptions("rmsprop","InitialLearnRate",learning_rate,'MaxEpochs',100,'MiniBatchSize',16,'Plots','training-progress');

 % 使用 Train 训练网络,得到新的网络模型 newnet 和训练信息 info

 [newnet,info]    = trainNetwork(Train, ly, opts);% 对测试集的图像进行分类,得到分类结果 predict 和分类概率 scores

 [predict,scores] = classify(newnet,Test);% 对测试集的图像进行分类,得到分类结果 predict 和分类概率 scores

 names  = Test.Labels; % 获取测试集中的标签

 pred   = (predict==names);% 判断分类结果是否正确,得到一个逻辑数组 pred

 s      = size(pred);% 获取 pred 的大小

 acc    = sum(pred)/s(1); % 计算分类准确率 acc

 fprintf('The accuracy of the test set is %f %% \NUM',acc*100);% 打印测试集的分类准确率

 

 

 

 

nameofs01 = '1';

nameofs02 = '2';

 

 

% 加载待分类的图像,并进行分类

img11     = imread('11.jpg');

img11     = imresize(img11,[227 227]);

predict11 = classify(newnet,img11);

img12     = imread('12.jpg');

img12     = imresize(img12,[227 227]);

predict12 = classify(newnet,img12);

img13     = imread('13.jpg');

img13     = imresize(img13,[227 227]);

predict13 = classify(newnet,img13);

 

 

img21     = imread('21.jpg');

img21     = imresize(img21,[227 227]);

predict21 = classify(newnet,img21);

img22     = imread('22.jpg');

img22     = imresize(img22,[227 227]);

predict22 = classify(newnet,img22);

img23     = imread('23.jpg');

img23     = imresize(img23,[227 227]);

predict23 = classify(newnet,img23);

 

 

 

 

 

figure;

subplot(231);

imshow(img11);

if predict11=='s01'

  title(['人脸检测结果:',nameofs01]);

elseif  predict11=='s02'

  title(['人脸检测结果:',nameofs02]);

elseif  predict11=='s03'

  title(['人脸检测结果:',nameofs03]);

end 

 

subplot(232);

imshow(img12);

if predict12=='s01'

  title(['人脸检测结果:',nameofs01]);

elseif  predict12=='s02'

  title(['人脸检测结果:',nameofs02]);

elseif  predict12=='s03'

  title(['人脸检测结果:',nameofs03]);

end 

 

 

subplot(233);

imshow(img13);

if predict13=='s01'

  title(['人脸检测结果:',nameofs01]);

elseif  predict13=='s02'

  title(['人脸检测结果:',nameofs02]);

elseif  predict13=='s03'

  title(['人脸检测结果:',nameofs03]);

end 

 

 

 

 

subplot(234);

imshow(img21);

if predict21=='s01'

  title(['人脸检测结果:',nameofs01]);

elseif  predict21=='s02'

  title(['人脸检测结果:',nameofs02]);

elseif  predict21=='s03'

  title(['人脸检测结果:',nameofs03]);

end 

 

subplot(235);

imshow(img22);

if predict22=='s01'

  title(['人脸检测结果:',nameofs01]);

elseif  predict22=='s02'

  title(['人脸检测结果:',nameofs02]);

elseif  predict22=='s03'

  title(['人脸检测结果:',nameofs03]);

end 

 

 

subplot(236);

imshow(img23);

if predict23=='s01'

  title(['人脸检测结果:',nameofs01]);

elseif  predict23=='s02'

  title(['人脸检测结果:',nameofs02]);

elseif  predict23=='s03'

  title(['人脸检测结果:',nameofs03]);

end`