携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第30天,点击查看活动详情
训练模型
接下来就开始定义优化器,这次采用算法更新模型参数,这里
- 每一次将一个批次数据喂给模型
- 通过对比真实标签和模型预测来计算损失
- 根据损失函数计算梯度
- 学习率乘以梯度的模来更新
这里采用 Adam 优化器,采用优化器默认参数来更新模型的参数。可以通过尝试不同优化器和学习率来找到合适的参数更新的策略。关于 Adam 这个优化器的特点在之前视频有过详细介绍
optimizer = optim.Adam(model.parameters())
我们定义目标,PyTorch 提供各种当下流行损失函数。这个函数。关于什么是目标函数(损失函数)相比大家比较清除
采用 crossEntropyLoss 包含两个部分第一个部分对预测结果做一个 softmax 将输出值变为一个概率分布,然后在通过 negative log likelihood
神经网络输出 是一个 10 维向量,而 y 是一个 label 用一个整数表示类别,这个损失函数,也就是通过 y 对应在 索引位置的值,这个值是模型给概率预测概率值。
然后 softmax
output = torch.ones(10)
output[0] = 5
hat_y = torch.softmax(output,dim=0)
hat_y
tensor([0.8585, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
0.0157])
y = 0
-torch.log(hat_y[y])
tensor(0.1526)
y = 5
-torch.log(hat_y[y])
tensor(4.1526)
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = criterion.to(device)
接下来定义函数用来计算模型准确性,也就是将预测出概率最高对应索引和真实标签索引进行对比较,如果相等则计数一次正确预测,然后除以批量样本数来计算准确性。
def calculate_accuracy(y_pred,y):
top_pred = y_pred.argmax(1,keepdim=True)
correct = top_pred.eq(y.view_as(top_pred)).sum()
acc = correct.float() / y.shape[0]
return acc
开始训练
- 将模型切换到
train也就是训练模式 - 通过 dataloader 来加载数据
- 清空上一批次计算的梯度
- 将一批次的图像数据传入模型来预测
- 计算预测值和真实值之间损失
- 根据预测结果计算模型准确性
- 计算梯度
- 更新参数
- 更新度量
model 在处于不同模式,例如 train 和 valid 一些层的机制是根据模型这个状态而不同,所以需要显示告诉模型现在所处模式是训练模式
def train(model, iterator, optimizer, criterion, device):
epoch_loss = 0
epoch_acc = 0
model.train()
for (x,y) in tqdm(iterator, desc="Training", leave=False):
x = x.to(device)
y = y.to(device)
optimizer.zero_grad()
y_pred,_ = model(x)
loss = criterion(y_pred,y)
acc = calculate_accuracy(y_pred,y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_acc += acc.item()
return epoch_loss / len(iterator),epoch_acc /len(iterator)
评估
- 将模型切换为
model.eval()评估模式 - 所有逻辑都包裹在
torch.no_grad()这个方法 - 不会再去计算梯度
- 在循环中也不会通过优化器来更新参数
def evaluate(model, iterator, criterion, device):
epoch_loss = 0
epoch_acc = 0
model.eval()
with torch.no_grad():
for (x,y) in tqdm(iterator, desc="Training", leave=False):
x = x.to(device)
y = y.to(device)
y_pred,_ = model(x)
loss = criterion(y_pred,y)
acc = calculate_accuracy(y_pred,y)
epoch_loss += loss.item()
epoch_acc += acc.item()
return epoch_loss / len(iterator),epoch_acc /len(iterator)
计算一个 epoch 的时间
def epoch_time(start_time, end_time):
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins,elapsed_secs
开始训练
EPOCHS = 5
best_valid_loss = float('inf')
for epoch in trange(EPOCHS):
start_time = time.monotonic()
train_loss, train_acc = train(model,train_iterator,optimizer,criterion,device)
valid_loss, valid_acc = evaluate(model,valid_iterator,criterion,device)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(),'tut1_model.pt')
end_time = time.monotonic()
epoch_mins, epoch_secs = epoch_time(start_time,end_time)
print(f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins} m {epoch_secs}s")
print(f'\tTrain Loss:{train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
print(f'\t Val. Loss:{valid_loss:.3f} | Val. Acc: {valid_acc * 100:.2f}%')
加载训练好的模型的参数,
model.load_state_dict(torch.load('tut1_model.pt'))
test_loss, test_acc = evaluate(model,test_iterator,criterion,device)
print(f"Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%")
Test Loss: 0.063 | Test Acc: 97.91%
评估模型
现在我们已经训练好了一个模型,接下里我们进一步分析模型,首先我们观察一些那些被模型分类错误样本,看一看这些样本是不是对于我们来说也是比较难于识别,也就是看一看是不是模型因为这些样本本身难于识别才做出错误判断
def get_predictions(model, iterator, device):
model.eval()
images = []
labels = []
probs = []
with torch.no_grad():
for (x,y) in iterator:
x = x.to(device)
y_pred,_ = model(x)
y_prob = F.softmax(y_pred,dim=-1)
images.append(x.cpu())
labels.append(y.cpu())
probs.append(y_prob.cpu())
images = torch.cat(images,dim=0)
labels = torch.cat(labels,dim=0)
probs = torch.cat(probs,dim=0)
return images, labels, probs
# 将获取预测结果,然后跟给出最高预测概率对应索引,获取预测label
images, labels, probs = get_predictions(model,test_iterator,device)
pred_labels = torch.argmax(probs,1)
#根据真实 label 和 预测 label 绘制混淆矩阵
def plot_confusion_matrix(labels,pred_labels):
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1,1,1)
cm = metrics.confusion_matrix(labels,pred_labels)
cm = metrics.ConfusionMatrixDisplay(cm,display_labels=range(10))
cm.plot(values_format='d',cmap='Blues',ax=ax)
plot_confusion_matrix(labels,pred_labels)
#接下里查看预测和真实预测结果相同,也就是分类正确的样本
corrects = torch.eq(labels,pred_labels)
incorrect_examples = []
for image, label, prob, correct in zip(images,labels,probs,corrects):
if not correct:
incorrect_examples.append((image,label,prob))
incorrect_examples.sort(reverse=True,key=lambda x: torch.max(x[2],dim=0).values)
#接下来我们来分析一下那些图像分类错误的样本
def plot_most_incorrect(incorrect, n_images):
rows = int(np.sqrt(n_images))
cols = int(np.sqrt(n_images))
fig = plt.figure(figsize=(20,10))
for i in range(rows * cols):
ax = fig.add_subplot(rows, cols, i +1)
image,true_label,probs = incorrect[i]
true_prob = probs[true_label]
incorrect_prob, incorrect_label = torch.max(probs,dim=0)
ax.imshow(image.view(28,28).cpu().numpy(),cmap='bone')
ax.set_title(f'true label:{true_label}({true_prob:.3f})\n'
f'pred_labels:{incorrect_label}({incorrect_prob:.3f})')
ax.axis('off')
fig.subplots_adjust(hspace=0.5)
N_IMAGES = 25
plot_most_incorrect(incorrect_examples,N_IMAGES)
上面图列出模型识别错误 25 图像,其中 true_label 表示为真实标签下模型给出该概率,然后 pred_label 是模型给出答案,并且后面是模型给出概率的概率
接下来我们要做的事就是看一看模型在识别过程中都看到什么样图像,从模型中获取输出,并且看一看第二隐藏层模型提取特征图
def get_representations(model,iterator,device):
model.eval()
outputs = []
intermediates = []
labels = []
with torch.no_grad():
for(x,y) in tqdm(iterator):
x = x.to(device)
y_pred,h = model(x)
outputs.append(y_pred.cpu())
intermediates.append(h.cpu())
labels.append(y)
outputs = torch.cat(outputs,dim=0)
intermediates = torch.cat(intermediates,dim=0)
labels = torch.cat(labels,dim=0)
return outputs, intermediates, labels
outputs, intermediates,labels = get_representations(model,train_iterator,device)
0%| | 0/844 [00:00<?, ?it/s]
通过降维可以将 10 维或者 100 维数据压缩到我们可以理解维度 2 维,然后将其进行绘制,这里采用降维技术维 PCA
def get_pca(data, n_components=2):
pca = decomposition.PCA()
pca.n_components = n_components
pca_data = pca.fit_transform(data)
return pca_data
def plot_representations(data,labels,n_images=None):
if n_images is not None:
data = data[:n_images]
labels = labels[:n_images]
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111)
scatter = ax.scatter(data[:,0],data[:,1],c=labels,cmap='tab10')
handles, labels = scatter.legend_elements()
ax.legend(handles=handles,labels=labels)
output_pca_data = get_pca(outputs)
plot_representations(output_pca_data,labels)
intermedia_pca_data = get_pca(intermediates)
plot_representations(intermedia_pca_data,labels)
def get_tsne(data,n_components=2,n_images=None):
if n_images is not None:
data = data[:n_images]
tsne = manifold.TSNE(n_components=n_components,random_state=0)
tsne_data = tsne.fit_transform(data)
return tsne_data
N_IMAGES = 5_000
output_tsne_data = get_tsne(outputs,n_images=N_IMAGES)
plot_representations(output_tsne_data,labels,n_images=N_IMAGES)
intermediate_tsne_data = get_tsne(intermediates,n_images=N_IMAGES)
plot_representations(intermediate_tsne_data,labels,n_images=N_IMAGES)
可以尝试做一个实验,随机生成一些噪音,然后通过模型,如果模型反馈为图像时,我们需要回来看一看这张噪音图像
def imagine_digit(model,digit,device,n_iterations=50_000):
model.eval()
best_prob = 0
best_image = None
with torch.no_grad():
for _ in trange(n_iterations):
x = torch.randn(32,28,28).to(device)
y_pred,_ = model(x)
preds = F.softmax(y_pred,dim=-1)
_best_prob, index = torch.max(preds[:,digit],dim=0)
if _best_prob > best_prob:
best_prob = _best_prob
best_image = x[index]
return best_image,best_prob
DIGIT = 3
best_image,best_prob = imagine_digit(model,DIGIT,device)
0%| | 0/50000 [00:00<?, ?it/s]
print(f"Best image probability: {best_prob.item()*100:.2f}")
Best image probability: 99.98
plt.imshow(best_image.cpu().numpy(),cmap='bone')
plt.axis('off')
(-0.5, 27.5, 27.5, -0.5)
模型以很高置信度,也就是概率认为上面这张图像是 3 ,对于我们来看,这就是一些杂乱无章像素组成的图像,可能是模型过拟合。我们可以将模型的一层权重绘制出来,看看在第一层是否从学到到了什么模式
def plot_weights(weights,n_weights):
rows = int(np.sqrt(n_weights))
cols = int(np.sqrt(n_weights))
fig = plt.figure(figsize=(20,10))
for i in range(rows*cols):
ax = fig.add_subplot(rows,cols,i+1)
ax.imshow(weights[i].view(28,28).cpu().numpy(),cmap='bone')
ax.axis('off')
N_WEIGHTS = 25
weights = model.input_fc.weight.data
plot_weights(weights,N_WEIGHTS)