RuntimeError: Skipping the `training_step` by returning None

While reproducing a model I ran into the following error: `RuntimeError: Skipping the training_step by returning None in distributed training is not supported. It is recommended that you rewrite your training logic to avoid having to skip the step in the first place.`

The code uses PyTorch Lightning, and whenever `train_a_step` is falsy, `training_step()` returns `None`, which is what makes PyTorch Lightning raise the error. The relevant code:

    def training_step(self, batch, batch_idx):
        x, y, scanner, filepath = batch
        
        # print('VOLTEI AO training_step')

        self.grammatrices = []

        self.stylemodel.eval()

        if self.mparams.continuous:
            # save checkpoint at scanner shift
            newshift = False
            for s in scanner:
                if s != self.mparams.order[0] and not self.scanner_checkpoints[s]:
                    newshift = True
                    shift_scanner = s
            if newshift:
                # print('new shift to', shift_scanner)
                exp_name = utils.get_expname(self.mparams)
                weights_path = self.modeldir + exp_name + '_shift_' + shift_scanner + '.pt'
                torch.save(self.model.state_dict(), weights_path)
                self.scanner_checkpoints[shift_scanner] = True
        
        # print('training_step1')

        if self.mparams.use_memory:
            #y = y[:, None]
            self.grammatrices = []
            if type(x) is list or type(x) is tuple:
                xstyle = torch.stack(x)
            elif x.size()[1] == 1 and self.mparams.dim != 3:
                xstyle = torch.cat([x, x, x], dim=1)
            else:
                xstyle = x
            _ = self.stylemodel(xstyle)

            train_a_step = self.insert_element(x, y, filepath, scanner)
            print(f'train_a_step: {train_a_step}')


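            # insert_element() decides whether this batch should trigger an actual
            # optimisation step; a falsy return value makes training_step() return
            # None, which is exactly what the RuntimeError complains about under DDP.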
            if train_a_step:
                x, y = self.trainingsmemory.get_training_batch(
                    self.mparams.batch_size,
                    batches=int(self.mparams.training_batch_size / self.mparams.batch_size)
                )
                self.train_counter += 1
                
            else:
                return None

At first I assumed the problem was the `return None` itself, but the error message says it is only triggered when running in distributed (DDP) mode. That reminded me that during startup the console had reported two GPUs on this machine (an RTX 2080 and an RTX 4080), so Lightning defaulted to DDP across both of them. Mixing GPUs of different models (such as an RTX 4080 and a 2080) in DDP training is problematic anyway: the architectural gap leads to mismatched compute speeds, which can cause synchronization timeouts.
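A quick way to confirm this is to check what PyTorch actually sees on the machine (a minimal check, not part of the original project code):

    import torch

    print(torch.cuda.device_count())  # 2 on this machine (2080 + 4080)
    for i in range(torch.cuda.device_count()):
        print(i, torch.cuda.get_device_name(i))

With more than one visible GPU and no explicit `devices`/`strategy` setting, Lightning selects a multi-GPU (DDP) strategy by default, which is why the error only showed up here.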

I therefore modified the Trainer as follows:

    trainer = Trainer(
        max_epochs=EPOCHS,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        logger=logger,
        val_check_interval=model.mparams.val_check_interval,
        gradient_clip_val=model.mparams.gradient_clip_val,
        strategy='auto',
        # strategy='ddp_find_unused_parameters_true'  # Add this line to enable detection of unused parameters in DDP
    )

I added `devices=1` and dropped `strategy='ddp_find_unused_parameters_true'` in favour of `strategy='auto'`, so training runs on a single GPU and Lightning no longer selects DDP.
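If you want to be explicit about which physical card is used (for example, to make sure the run lands on the RTX 4080 rather than whichever device happens to be index 0), you can also restrict GPU visibility before building the Trainer. A small sketch; the index `1` is only an assumption and depends on how the GPUs are enumerated on your machine:

    import os

    # Hypothetical: expose only one physical GPU to this process. Set this before
    # any CUDA call; check torch.cuda.get_device_name(i) first to find the 4080.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'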

That resolved the error.
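If multi-GPU DDP is actually needed later, the fix suggested by the error message is to never skip the step in the first place. A common workaround is to return a zero loss that still depends on the model parameters instead of `None`, so every rank runs backward and the gradient all-reduce stays in sync. A minimal sketch with illustrative names (not the original project code):

    import torch
    import pytorch_lightning as pl


    class SkipSafeModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = torch.nn.Linear(10, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            if self.should_skip(batch):
                # Zero loss that is still connected to the graph: the gradients it
                # produces are all zero, but backward() and DDP's all-reduce run
                # on this rank, so the other ranks are not left waiting.
                return sum(p.sum() for p in self.model.parameters()) * 0.0
            return torch.nn.functional.mse_loss(self.model(x), y)

        def should_skip(self, batch):
            # placeholder for whatever condition made the original code return None
            return False

        def configure_optimizers(self):
            return torch.optim.SGD(self.model.parameters(), lr=1e-2)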