PyTorch model parameters (convolutions) not moved to CUDA
Error message: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
Conclusion up front: PyTorch does not recognize modules stored in a native Python list; switch the list to nn.ModuleList() and the error disappears.
Broken version:
import torch.nn as nn

class my_layer(nn.Module):
    def __init__(self):
        super(my_layer, self).__init__()
        self.dim = [1000, 500, 200]
        self.length = len(self.dim) - 1
        # Bug: a plain Python list is invisible to nn.Module's registration machinery
        self.layers = []
        for i in range(self.length):
            self.layers.append(nn.Conv2d(self.dim[i], self.dim[i+1], 3, padding=1))

    def forward(self, x):
        for i in range(self.length):
            x = self.layers[i](x)
        return x
Here we define a custom network in which every convolution (sub-module) is stored in a plain Python list. Running it on the GPU produces the error above: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same.
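The reason is that model.cuda() only moves parameters PyTorch has registered, and the convolutions hidden inside the plain list are never registered, so their weights stay on the CPU. A minimal sketch of triggering the error (the input shape is just an illustrative assumption):

import torch

model = my_layer().cuda()
# The list-held convolutions were never registered, so no parameters exist:
print(len(list(model.parameters())))    # 0 -- .cuda() had nothing to move
x = torch.randn(1, 1000, 8, 8).cuda()   # assumed input: batch 1, 1000 channels
y = model(x)                            # raises the RuntimeError above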
Working version 1 (clumsy):
class my_layer(nn.Module):
    def __init__(self):
        super(my_layer, self).__init__()
        self.dim = [1000, 500, 200]
        self.length = len(self.dim) - 1
        self.layername = []
        for i in range(self.length):
            layer = nn.Conv2d(self.dim[i], self.dim[i+1], 3, padding=1)
            layer_name = 'layer{}'.format(i + 1)
            # add_module registers the conv so .cuda() / .parameters() can see it
            self.add_module(layer_name, layer)
            self.layername.append(layer_name)

    def forward(self, x):
        for layer_name in self.layername:
            layer = getattr(self, layer_name)
            x = layer(x)
        return x
Here we again define a custom network, but this time, instead of storing the modules themselves in a list, we store the module names. Each module is registered under its name with self.add_module(layer_name, layer), and in the forward function we retrieve it by name with getattr(self, layer_name). Run this version and the error is gone.
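add_module puts each conv into the module's registry, which is exactly what .cuda() and .parameters() walk. A quick check of what got registered (the printed values are what I would expect, not copied from a run):

model = my_layer()
print([name for name, _ in model.named_children()])  # ['layer1', 'layer2']
print(len(list(model.parameters())))                 # 4: weight + bias per conv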
Working version 2 (recommended):
Method 1 works, but it is cumbersome, and it sidesteps the problem rather than fixing it at the source. Instead, simply swap the native Python list for self.layers = nn.ModuleList(); everything else stays the same and the problem is solved.
The root cause is that modules held in a native Python list are never registered as sub-modules, so they are not recognized when the model is loaded onto the GPU.
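The registration happens in nn.Module.__setattr__: an attribute whose value is itself a Module (and nn.ModuleList is one) is recorded as a sub-module, while a plain list passes through unnoticed. A small illustration of that behavior:

import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super(Demo, self).__init__()
        self.a = [nn.Conv2d(3, 3, 3)]                  # plain list: not registered
        self.b = nn.ModuleList([nn.Conv2d(3, 3, 3)])   # ModuleList: registered

d = Demo()
print([name for name, _ in d.named_children()])  # ['b'] -- only the ModuleList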
class my_layer(nn.Module):
    def __init__(self):
        super(my_layer, self).__init__()
        self.dim = [1000, 500, 200]
        self.length = len(self.dim) - 1
        # nn.ModuleList registers every appended module as a sub-module
        self.layers = nn.ModuleList()
        for i in range(self.length):
            self.layers.append(nn.Conv2d(self.dim[i], self.dim[i+1], 3, padding=1))

    def forward(self, x):
        for i in range(self.length):
            x = self.layers[i](x)
        return x
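With nn.ModuleList in place, the earlier repro runs cleanly, because .cuda() now moves every registered conv weight together with the model (the input shape is again just an illustrative assumption):

import torch

model = my_layer().cuda()
print(len(list(model.parameters())))    # 4 -- both convs are registered now
x = torch.randn(1, 1000, 8, 8).cuda()
y = model(x)                            # no RuntimeError
print(y.shape)                          # torch.Size([1, 200, 8, 8])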