RuntimeError:视图大小与输入张量的大小和跨度不兼容(至少有一个维度跨越两个连续的子空间)。使用.reshape(...)代替。
造成这种情况的代码是
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
这个错误以前没有发生过,所以是由新版本的pytorch引入的(现在我使用的pytorch是1.8.1)。
在打印数组的时候,我发现它是一个布尔数组。
解决方法
在view() 前面加上.contiguous() 或用reshape 来代替view
所以把这一行
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
改为
correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
或
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)