修改pytorch的pth模型中的参数名

370 阅读1分钟
  • 修改目标:将所有以fem开头的key,都修改成以ccf开头。
  • 修改方式:dict['目标参数名']=dict.pop('原参数名')
dict = torch.load('./back.pth')
# 在该模型中,要寻找的key都在"net"参数下,所以要先用"dict['net']"来限定
for k, v in list(dict['net'].items()):  # 如果不加一层list,修改动作会导致报错:【RuntimeError: OrderedDict mutated during iteration】  
    if k[0:3]=='fem':
        dict['net']['cff'+k[3:]]=dict['net'].pop(k)
torch.save(dict, './save.pth')