[PyTorch pitfalls] Fixing two errors encountered while training GCN-GAN


I ran an open-source project from GitHub (github.com/jiangqn/GCN…) and hit two errors during training.

Error 1

Error message:

Traceback (most recent call last):
  File "..\GCN-GAN-pytorch\test_data.py", line 31, in <module>
    config = yaml.load(open('config.yml'))
TypeError: load() missing 1 required positional argument: 'Loader'

Cause: I am using PyYAML 6.0. Since PyYAML 5.1, calling yaml.load(file) without an explicit Loader has been deprecated as unsafe, because the full loader can construct arbitrary Python objects; you are expected to specify a loader (e.g. FullLoader, which refuses to execute arbitrary code). In PyYAML 6.0 the Loader argument became mandatory, hence the TypeError.

Fix: any of the following three variants resolves the problem:

d1 = yaml.load(file, Loader=yaml.FullLoader)
d1 = yaml.safe_load(file)
d1 = yaml.load(file, Loader=yaml.CLoader)
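As a quick sanity check, here is a minimal sketch (assuming PyYAML 5.1+ is installed) of parsing config values with the safe loader; the keys are made-up placeholders, not the project's actual config:

```python
import yaml

# safe_load only constructs plain Python objects (dict, list, str, int, ...),
# so an untrusted config file cannot trigger arbitrary code execution
config = yaml.safe_load("learning_rate: 0.001\nbatch_size: 512")
print(config)  # → {'learning_rate': 0.001, 'batch_size': 512}
```

safe_load is the simplest choice when the YAML file only contains plain scalars and mappings, which is the usual case for a training config.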

Fix adapted from: "TypeError: load() missing 1 required positional argument: 'Loader'" (one_wangtester, CSDN blog)

Error 2

Error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Cause: likely a PyTorch version difference; newer releases check tensor version counters during the backward pass. My PyTorch is 1.13.1, while the version the source code targeted is unknown. Concretely, discriminator_optimizer.step() updates the discriminator's weights in place, so when generator_loss.backward() later tries to backpropagate through those same weights, their version counters no longer match what the saved graph expects.
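Following the hint in the error message, anomaly detection makes PyTorch record forward-pass tracebacks so the failing op can be located. A minimal sketch of enabling it (with a trivially healthy graph, just to show the usage):

```python
import torch

# With anomaly detection on, an in-place-modification error during backward()
# also prints the forward-pass traceback of the op whose output was modified.
with torch.autograd.set_detect_anomaly(True):
    x = torch.randn(3, requires_grad=True)
    y = (x * 2).sum()
    y.backward()

print(x.grad)  # gradient of sum(2x) w.r.t. x is all twos
```

Anomaly detection slows training noticeably, so enable it only while debugging.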

The offending code:

for epoch in range(config['gan_epoches']):
    for i, (data, sample) in enumerate(zip(train_loader, sample_loader)):
        # update discriminator
        discriminator_optimizer.zero_grad()
        generator_optimizer.zero_grad()
        in_shots, out_shot = data
        in_shots, out_shot = in_shots.cuda(), out_shot.cuda()
        predicted_shot = generator(in_shots)
        _, sample = sample
        sample = sample.cuda()
        sample = sample.view(config['batch_size'], -1)
        real_logit = discriminator(sample).mean()
        fake_logit = discriminator(predicted_shot).mean()
        discriminator_loss = -real_logit + fake_logit
        discriminator_loss.backward(retain_graph=True)
        discriminator_optimizer.step()
        for p in discriminator.parameters():
            p.data.clamp_(-config['weight_clip'], config['weight_clip'])
        # update generator
        generator_loss = -fake_logit
        generator_loss.backward()
        generator_optimizer.step()
        out_shot = out_shot.view(config['batch_size'], -1)
        mse_loss = mse(predicted_shot, out_shot)
        print('[epoch %d] [step %d] [d_loss %.4f] [g_loss %.4f] [mse_loss %.4f]' % (epoch, i,
                discriminator_loss.item(), generator_loss.item(), mse_loss.item()))

Fix: reorder the generator and discriminator updates. The modified code:

for epoch in range(config['gan_epoches']):
    for i, (data, sample) in enumerate(zip(train_loader, sample_loader)):
        in_shots, out_shot = data
        in_shots, out_shot = in_shots.cuda(), out_shot.cuda()
        predicted_shot = generator(in_shots)
        _, sample = sample
        sample = sample.cuda()
        sample = sample.view(config['batch_size'], -1)
        predicted_shot = predicted_shot.squeeze(0)
        # update generator
        fake_logit = discriminator(predicted_shot)
        generator_loss = -fake_logit.mean()
        # fake_logit = discriminator(predicted_shot).mean()
        # generator_loss = -fake_logit
        generator_optimizer.zero_grad()
        generator_loss.backward()
        generator_optimizer.step()
        # update discriminator
        real_logit = discriminator(sample)
        fake_logit = discriminator(predicted_shot.detach())
        discriminator_loss = -real_logit.mean() + fake_logit.mean()
        # real_logit = discriminator(sample).mean()
        # discriminator_loss = -real_logit + fake_logit
        discriminator_optimizer.zero_grad()
        discriminator_loss.backward(retain_graph=True)
        discriminator_optimizer.step()
        for p in discriminator.parameters():
            p.data.clamp_(-config['weight_clip'], config['weight_clip'])
        out_shot = out_shot.view(config['batch_size'], -1)
        mse_loss = mse(predicted_shot, out_shot)
        print('[epoch %d] [step %d] [d_loss %.4f] [g_loss %.4f] [mse_loss %.4f]' % (epoch, i,
                discriminator_loss.item(), generator_loss.item(), mse_loss.item()))
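The core of the fix is the ordering plus the detach: update the generator first, while the discriminator weights are still at the version the graph was built with, then update the discriminator on a detached fake sample so its backward pass never walks back into the generator's (already freed) graph. A minimal self-contained sketch of this pattern, using tiny hypothetical linear models in place of the project's GCN generator and discriminator:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
generator = nn.Linear(4, 4)       # stand-in for the GCN generator
discriminator = nn.Linear(4, 1)   # stand-in for the discriminator
g_opt = torch.optim.RMSprop(generator.parameters(), lr=1e-3)
d_opt = torch.optim.RMSprop(discriminator.parameters(), lr=1e-3)

noise = torch.randn(8, 4)
real = torch.randn(8, 4)
fake = generator(noise)

# Generator update first: its graph references the discriminator weights
# as they are now, before any in-place optimizer step has touched them.
g_loss = -discriminator(fake).mean()
g_opt.zero_grad()
g_loss.backward()
g_opt.step()

# Discriminator update on a detached fake: backward() stops at fake,
# so it never needs the generator graph that was freed above.
d_loss = -discriminator(real).mean() + discriminator(fake.detach()).mean()
d_opt.zero_grad()
d_loss.backward()
d_opt.step()
```

Note that with the detached fake, the discriminator's backward pass owns a fresh graph, so in this sketch no retain_graph=True is needed.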

After these changes, training runs successfully.


Reference: "Fixing a GAN training error: one of the variables needed for gradient computation has been modified by an inplace operation" (CSDN blog)