如何解决视图大小与输入张量的大小和跨度不兼容

82 阅读1分钟

RuntimeError:视图大小与输入张量的大小和跨度不兼容(至少有一个维度跨越两个连续的子空间)。使用.reshape(...)代替。

造成这种情况的代码是

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

这个错误以前没有发生过,所以是由新版本的pytorch引入的(现在我使用的pytorch是1.8.1)。

在打印数组的时候,我发现它是一个布尔数组。

解决方法

view() 前面加上.contiguous() 或用reshape 来代替view

所以把这一行

correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)

改为

correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)

correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)