Multi-Attribute Face Recognition with PyTorch


Data source: the CelebA face dataset, an open dataset from the Chinese University of Hong Kong, contains 202,599 face images of 10,177 celebrity identities, all annotated with attribute labels, which makes it a very convenient dataset for face-related training. There are 40 attributes in total; the full list is available on the official site. Without further ado, let's walk through the process.
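Each line of the attribute annotation file holds an image name followed by 40 values, 1 if the attribute is present and -1 if it is absent. A minimal parsing sketch (the sample line below is made up for illustration, not taken from the real file):

```python
# Hypothetical line in the style of CelebA's attribute file:
# an image name followed by 40 values (1 = attribute present, -1 = absent)
line = "000001.jpg " + " ".join(["-1", "1"] * 20)

parts = line.split()
img_name = parts[0]
# Map -1 -> 0 and 1 -> 1, the same encoding the dataset class below uses
labels = [0 if v == '-1' else 1 for v in parts[1:]]

print(img_name)     # 000001.jpg
print(len(labels))  # 40
```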

The whole pipeline breaks down roughly into these steps:

1. Image preprocessing

2. Building the network

3. Training

4. Testing

5. Optimization

1. Image loading. Since the source data is unprocessed, we subclass torch.utils.data.Dataset to handle the images before they can be loaded. The code:




    def default_loader(path):
        try:
            img = Image.open(path)
            return img.convert('RGB')
        except IOError:
            print("Cannot open {0}".format(path))

    class myDataset(Data.Dataset):
        def __init__(self, img_dir, img_txt, transform=None, loader=default_loader):
            img_list = []
            img_labels = []
            with open(img_txt, 'r') as fp:
                for line in fp.readlines():
                    if len(line.split()) != 41:  # image name + 40 attribute values
                        continue
                    img_list.append(line.split()[0])
                    img_label_single = []
                    for value in line.split()[1:]:
                        if value == '-1':
                            img_label_single.append(0)
                        if value == '1':
                            img_label_single.append(1)
                    img_labels.append(img_label_single)
            self.imgs = [os.path.join(img_dir, file) for file in img_list]
            self.labels = img_labels
            self.transform = transform
            self.loader = loader

        def __len__(self):
            return len(self.imgs)

        def __getitem__(self, index):
            img_path = self.imgs[index]
            label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
            img = self.loader(img_path)
            if self.transform is not None:
                try:
                    img = self.transform(img)
                except Exception:
                    print('Cannot transform image: {}'.format(img_path))
            return img, label


Image augmentation, normalization, and loading:




    transform = transforms.Compose([
        transforms.Resize(40),
        transforms.CenterCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
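With mean and std both 0.5 per channel, Normalize maps pixel values from [0, 1] to [-1, 1] via out = (in - mean) / std. A quick framework-free check of the endpoints (the helper below just mirrors the formula; it is not part of torchvision):

```python
def normalize(x, mean=0.5, std=0.5):
    # Same per-channel formula transforms.Normalize applies after ToTensor()
    return (x - mean) / std

print(normalize(0.0))  # -1.0  (black pixel)
print(normalize(0.5))  # 0.0   (mid gray)
print(normalize(1.0))  # 1.0   (white pixel)
```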


    # Training set
    train_dataset = myDataset(img_dir=img_root, img_txt=train_txt, transform=transform)
    train_dataloader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Test set
    test_dataset = myDataset(img_dir=img_root, img_txt=test_txt, transform=transform)
    test_dataloader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


Building the network: each attribute gets its own branch of 3 convolutional layers plus 3 fully connected layers. The structure is fairly simple, so accuracy will not be especially high; if you are interested in optimizing it, some ideas are offered at the end of the article for discussion. First, the code:




    def make_conv():
        return nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            #nn.Dropout(0.5),
            nn.MaxPool2d(2)
        )

    def make_fc():
        return nn.Sequential(
            nn.Linear(64*4*4, 128),
            nn.ReLU(),
            #nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),  # Dropout() helps against overfitting; other positions may give surprising results, so experiment if you can
            nn.Linear(64, 2)
        )

    class face_attr(nn.Module):
        def __init__(self):
            super(face_attr, self).__init__()
            #attr0
            self.attr0_layer1 = make_conv()
            self.attr0_layer2 = make_fc()
            #attr1
            self.attr1_layer1 = make_conv()
            self.attr1_layer2 = make_fc()
            ...  # every attribute branch is identical, so attr2-attr37 are omitted here
            #attr38
            self.attr38_layer1 = make_conv()
            self.attr38_layer2 = make_fc()
            #attr39
            self.attr39_layer1 = make_conv()
            self.attr39_layer2 = make_fc()

        def forward(self, x):
            out_list = []
            #out0
            out0 = self.attr0_layer1(x)
            out0 = out0.view(out0.size(0), -1)
            out0 = self.attr0_layer2(out0)
            out_list.append(out0)
            ...  # out1-out38 follow the same pattern
            #out39
            out39 = self.attr39_layer1(x)
            out39 = out39.view(out39.size(0), -1)
            out39 = self.attr39_layer2(out39)
            out_list.append(out39)
            return out_list


Now we can train the network. When defining the optimizer you can set weight_decay=1e-8, which also helps prevent overfitting to some extent.





    module = face_attr()
    #print(module)
    optimizer = optim.Adam(module.parameters(), lr=0.001, weight_decay=1e-8)
    loss_list = []
    for i in range(40):
        loss_list.append(nn.CrossEntropyLoss())
    for Epoch in range(50):
        all_correct_num = 0
        for ii, (img, label) in enumerate(train_dataloader):
            img = Variable(img)
            label = Variable(label)
            output = module(img)  # list of 40 tensors, one per attribute
            optimizer.zero_grad()
            for i in range(40):
                loss = loss_list[i](output[i], label[:, i])
                loss.backward()
                _, predict = torch.max(output[i], 1)
                correct_num = (predict == label[:, i]).sum()
                all_correct_num += correct_num.item()
            optimizer.step()
        Accuracy = all_correct_num * 1.0 / (len(train_dataset) * 40.0)
        print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch, all_correct_num, Accuracy))
        torch.save(module, 'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')  # save the model after every epoch



Testing the network: similar to training, except with no optimizer step and no backpropagation.





    module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')  # load the model saved above
    module.eval()  # switch to evaluation mode
    all_correct_num = 0
    for ii, (img, label) in enumerate(test_dataloader):
        img = Variable(img)
        label = Variable(label)
        output = module(img)
        for i in range(40):
            _, predict = torch.max(output[i], 1)
            correct_num = (predict == label[:, i]).sum()
            all_correct_num += correct_num.item()
    Accuracy = all_correct_num * 1.0 / (len(test_dataset) * 40.0)
    print('all_correct_num={0},Accuracy={1}'.format(all_correct_num, Accuracy))



Summary: I work on a laptop with no GPU, so I used only about 5,000 samples for training (even then, 50 epochs took my machine two days) and 1,000 for testing. Test accuracy came out around 90%. If you have the hardware, here are some directions worth optimizing:

1. Image augmentation. Because my machine could not handle larger data, I resized the original images to 40x40 and center-cropped 32x32 as the input; with a GPU you could consider a 128x128 input.

2. Before I added Dropout() the model overfit, which is also related to the small training set; I would suggest around 60,000 training samples. As mentioned above, try Dropout() in several positions; I placed it just before the final output layer. I look forward to hearing the results of your experiments.

3. Every attribute currently uses the same network, but that network may not be the best choice for each attribute. You could design a branch per attribute: one attribute might reach high accuracy with fully connected layers alone, while another might need 4 conv + 2 fc layers. Only repeated experiments will tell; no algorithm can simply hand you the best model. I suggest printing each attribute's accuracy and then tuning the network for the attributes with the lowest scores.
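One way to see which attributes lag behind is to keep one correct-answer counter per attribute instead of the single global counter used above. A framework-free sketch with made-up predictions and labels for 3 attributes:

```python
# One row per sample, one column per attribute (3 attributes here for brevity)
predictions = [[1, 0, 1], [0, 0, 1], [1, 1, 0]]
labels      = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]

num_attrs = len(labels[0])
correct = [0] * num_attrs
for pred_row, label_row in zip(predictions, labels):
    for i in range(num_attrs):
        if pred_row[i] == label_row[i]:
            correct[i] += 1

# Per-attribute accuracy exposes the weak attributes worth extra tuning
per_attr_acc = [c / len(labels) for c in correct]
print(per_attr_acc)
```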

4. Get a few GPUs for training and testing; without one it really is slow going. I hope some of you can build on this network and reach higher accuracy.
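For reference, moving training to a GPU in recent PyTorch is mostly a matter of putting the model and every batch on the same device. A minimal sketch (the Linear layer is a hypothetical stand-in for face_attr(), and the code falls back to CPU when no GPU is present):

```python
import torch

# Use the GPU when one is available, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.nn.Linear(4, 2).to(device)  # stand-in for face_attr()
x = torch.randn(3, 4).to(device)          # every batch must be moved too
out = model(x)
print(out.shape)  # torch.Size([3, 2])
```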

The complete code is attached below:





    # -*- coding: utf-8 -*-
    """
    Created on Sun Jun 17 11:54:36 2018
    @author: sky-hole
    """
    import os

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data as Data
    import torchvision.transforms as transforms
    from PIL import Image
    from torch.autograd import Variable

    img_root = 'W:/pic_data/face/CelebA/Img/img_align_celeba'
    train_txt = 'W:/pic_data/face/CelebA/Img/train10000.txt'
    batch_size = 2

    def default_loader(path):
        try:
            img = Image.open(path)
            return img.convert('RGB')
        except IOError:
            print("Cannot open {0}".format(path))

    class myDataset(Data.Dataset):
        def __init__(self, img_dir, img_txt, transform=None, loader=default_loader):
            img_list = []
            img_labels = []
            with open(img_txt, 'r') as fp:
                for line in fp.readlines():
                    if len(line.split()) != 41:  # image name + 40 attribute values
                        continue
                    img_list.append(line.split()[0])
                    img_label_single = []
                    for value in line.split()[1:]:
                        if value == '-1':
                            img_label_single.append(0)
                        if value == '1':
                            img_label_single.append(1)
                    img_labels.append(img_label_single)
            self.imgs = [os.path.join(img_dir, file) for file in img_list]
            self.labels = img_labels
            self.transform = transform
            self.loader = loader

        def __len__(self):
            return len(self.imgs)

        def __getitem__(self, index):
            img_path = self.imgs[index]
            label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
            img = self.loader(img_path)
            if self.transform is not None:
                try:
                    img = self.transform(img)
                except Exception:
                    print('Cannot transform image: {}'.format(img_path))
            return img, label

    transform = transforms.Compose([
        transforms.Resize(40),
        transforms.CenterCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])

    train_dataset = myDataset(img_dir=img_root, img_txt=train_txt, transform=transform)
    train_dataloader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    #print(len(train_dataset))
    #print(len(train_dataloader))

    def make_conv():
        return nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            #nn.Dropout(0.5),
            nn.MaxPool2d(2)
        )

    def make_fc():
        return nn.Sequential(
            nn.Linear(64*4*4, 128),
            nn.ReLU(),
            #nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 2)
        )

    class face_attr(nn.Module):
        def __init__(self):
            super(face_attr, self).__init__()
            # 40 identical branches, one per attribute; the ModuleLists are
            # equivalent to the attr0_layer1 ... attr39_layer2 pairs written
            # out by hand in the walkthrough above
            self.conv_layers = nn.ModuleList([make_conv() for _ in range(40)])
            self.fc_layers = nn.ModuleList([make_fc() for _ in range(40)])

        def forward(self, x):
            out_list = []
            for conv, fc in zip(self.conv_layers, self.fc_layers):
                out = conv(x)
                out = out.view(out.size(0), -1)
                out = fc(out)
                out_list.append(out)
            return out_list

    module = face_attr()
    #print(module)
    optimizer = optim.Adam(module.parameters(), lr=0.001, weight_decay=1e-8)
    loss_list = []
    for i in range(40):
        loss_list.append(nn.CrossEntropyLoss())
    for Epoch in range(50):
        all_correct_num = 0
        for ii, (img, label) in enumerate(train_dataloader):
            img = Variable(img)
            label = Variable(label)
            output = module(img)  # list of 40 tensors, one per attribute
            optimizer.zero_grad()
            for i in range(40):
                loss = loss_list[i](output[i], label[:, i])
                loss.backward()
                _, predict = torch.max(output[i], 1)
                correct_num = (predict == label[:, i]).sum()
                all_correct_num += correct_num.item()
            optimizer.step()
        Accuracy = all_correct_num * 1.0 / (len(train_dataset) * 40.0)
        print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch, all_correct_num, Accuracy))
        torch.save(module, 'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
    '''
    test_txt = 'W:/pic_data/face/CelebA/Img/test1000.txt'
    test_dataset = myDataset(img_dir=img_root, img_txt=test_txt, transform=transform)
    test_dataloader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
    module.eval()
    all_correct_num = 0
    for ii, (img, label) in enumerate(test_dataloader):
        img = Variable(img)
        label = Variable(label)
        output = module(img)
        for i in range(40):
            _, predict = torch.max(output[i], 1)
            correct_num = (predict == label[:, i]).sum()
            all_correct_num += correct_num.item()
    Accuracy = all_correct_num * 1.0 / (len(test_dataset) * 40.0)
    print('all_correct_num={0},Accuracy={1}'.format(all_correct_num, Accuracy))
    '''
