主要利用和matplotlib和imageio这两个python库。由于matplotlib无法直接得到所绘图的RGB值,所以每次画完一帧图后,保存下来再读取得到每一帧的RGB值,最后使用imageio将所有的帧连接起来组合成一个动图。这种方法是很多生成动图的方法中较为简单的一种,但是因为每次都要保存和读取图片,所以会增加一定的程序耗时。
下面是使用pytorch写的一个简单的线性回归的例子:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio
torch.manual_seed(0)
num_samples = 100
x_train = torch.linspace(0, 1, num_samples)
y_train = 0.1 * x_train + 0.2 + torch.randn(num_samples)*0.03
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD([w,b], lr=0.01)
images = []
num_epochs = 4000
for epoch in range(num_epochs):
y_pred = w * x_train + b
loss = criterion(y_pred, y_train)
optimizer.zero_grad()
loss.backward()
if epoch % 100 == 99:
plt.figure()
plt.ylim(torch.min(y_train).item(), torch.max(y_train).item())
plt.scatter(x_train.tolist(), y_train.tolist(), marker='.')
plt.plot(x_train.tolist(), y_pred.tolist(), color='r', linewidth=2)
plt.title('Epoch [{}/{}], Loss: {:.6f}, \n Weight: {:.6f}, Bias: {:.6f}'
.format(epoch+1, num_epochs, loss.item(), w.item(), b.item()))
plt.savefig('a.png')
plt.close()
images.append(imageio.imread('a.png'))
optimizer.step()
imageio.mimsave('gen.gif', images, duration=0.5)
最后生成的动图如下:
