Formula Recognition Training Problems


1. Training interrupted

Loss: 0.4968:  88%|████████████████████▎  | 17173/19457 [44:41<05:56,  6.40it/s]
Traceback (most recent call last):
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 97, in <module>
    train(args)
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 53, in train
    loss.backward()
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: transform: failed to synchronize: cudaErrorAssert: device-side assert triggered

Cause: the token count of some labels exceeds the limit set in the model parameters.

Fix: compute the maximum token count over the dataset and set the parameter accordingly.
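A quick way to find the right value is to tokenize every training label and take the longest one. A minimal sketch, assuming the tokenizer/label-file paths and the max_seq_len config key used by the LaTeX-OCR repo (adjust to your own layout):

from tokenizers import Tokenizer

# Paths are assumptions; point them at your own tokenizer and label file.
tokenizer = Tokenizer.from_file("dataset/tokenizer.json")

with open("dataset/data/math.txt", encoding="utf-8") as f:
    formulas = [line.strip() for line in f if line.strip()]

# Longest tokenized formula in the training labels
max_tokens = max(len(tokenizer.encode(formula).ids) for formula in formulas)
print("max token count:", max_tokens)

Then set max_seq_len in the training config to at least this value (leaving room for BOS/EOS tokens if the dataloader adds them).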

2. Training fails to run

Traceback (most recent call last):
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 100, in <module>
    train(args)
  File "/media/newData/user/dxq/LaTeX-OCR/train.py", line 54, in train
    encoded = encoder(im.to(device))
  File "/home/appuser/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/appuser/.local/lib/python3.6/site-packages/timm/models/vision_transformer.py", line 374, in forward
    x = self.forward_features(x)
  File "/media/newData/user/dxq/LaTeX-OCR/models.py", line 80, in forward_features
    x += self.pos_embed[:, pos_emb_ind]
RuntimeError: The size of tensor a (25) must match the size of tensor b (12) at non-singleton dimension 1

The code where the problem occurs is here:

    def forward_features(self, x):
        B, c, h_init, w_init = x.shape
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        total_channel = x.shape[1]
        # corrected line: round up so the patch grid matches what patch_embed actually produced
        h, w = math.ceil(h_init/self.patch_size), math.ceil(w_init/self.patch_size)
        # if h*w+1 != total_channel:
        #     h, w = w_init // self.patch_size,w_init // self.patch_size

        # row-major indices into the pretrained positional-embedding grid; index 0 is reserved for the cls token
        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
        pos_embed = self.pos_embed
        x += self.pos_embed[:, pos_emb_ind]  # this is the line that raised the size-mismatch RuntimeError
        # x = x + self.pos_embed
        x = self.pos_drop(x)
        # torch.cat((torch.zeros(1),repeat(torch.arange(h) * (self.width // self.patch_size - w), 'h -> (h w)', w=w) + torch.arange(h * w)+1), dim=0).long()
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x
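To see what the pos_emb_ind construction does, here is a tiny worked example (the grid width and patch counts are made-up numbers, not taken from the failing run): it picks, row by row, the first w positions of each patch row from the pretrained positional-embedding grid, with index 0 reserved for the cls token.

import torch
from einops import repeat

# Assumed toy sizes: pretrained grid is 8 patches wide, current image gives h=2 rows, w=3 columns
h, w, grid_w = 2, 3, 8
pos_emb_ind = repeat(torch.arange(h) * (grid_w - w), 'h -> (h w)', w=w) + torch.arange(h * w)
print(pos_emb_ind)   # tensor([ 0,  1,  2,  8,  9, 10])
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
print(pos_emb_ind)   # tensor([ 0,  1,  2,  3,  9, 10, 11])

The index tensor has exactly 1 + h*w entries, so it only lines up with x when h and w really equal the patch-grid size that patch_embed produced.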

The generated dimensions are inconsistent; the patch-grid size should be rounded up rather than truncated. That is,

 h, w = w_init // self.patch_size,w_init // self.patch_size

should become

h, w = math.ceil(h_init/self.patch_size), math.ceil(w_init/self.patch_size)
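To make the rounding point concrete, a small arithmetic sketch (patch_size and the input height here are illustrative assumptions, not the values from the failing run):

import math

patch_size = 16
h_init = 40                                  # not an exact multiple of patch_size

h_floor = h_init // patch_size               # 2 -> drops the partial last patch row
h_ceil = math.ceil(h_init / patch_size)      # 3 -> counts it

# If patch_embed emits patches for that partial row as well, the floor-based
# pos_emb_ind ends up shorter than x, and `x += self.pos_embed[:, pos_emb_ind]`
# fails with the size-mismatch RuntimeError shown above.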

So frustrating. Debugged like crazy.

The bug itself is simple, and fixing it is simple too; the hard part was locating where the bug actually was.
