解决pytorch 的state_dict()拷贝问题
先说结论
model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。
应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数)优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
其实看了如下代码的输出应该就懂了
importtorch importtorch.nnasnn importtorchvision importnumpyasnp fromtorchsummaryimportsummary #Definemodel classTheModelClass(nn.Module): def__init__(self): super(TheModelClass,self).__init__() self.conv1=nn.Conv2d(3,6,5) self.pool=nn.MaxPool2d(2,2) self.conv2=nn.Conv2d(6,16,5) self.fc1=nn.Linear(16*5*5,120) self.fc2=nn.Linear(120,84) self.fc3=nn.Linear(84,10) defforward(self,x): x=self.pool(F.relu(self.conv1(x))) x=self.pool(F.relu(self.conv2(x))) x=x.view(-1,16*5*5) x=F.relu(self.fc1(x)) x=F.relu(self.fc2(x)) x=self.fc3(x) returnx #Initializemodel model=TheModelClass() #Initializeoptimizer optimizer=torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.9) #Printmodel'sstate_dict print("Model'sstate_dict:") forparam_tensorinmodel.state_dict(): print(param_tensor,"\t",model.state_dict()[param_tensor].size()) #Printoptimizer'sstate_dict print("Optimizer'sstate_dict:") forvar_nameinoptimizer.state_dict(): print(var_name,"\t",optimizer.state_dict()[var_name])
输出如下:
Model'sstate_dict: conv1.weighttorch.Size([6,3,5,5]) conv1.biastorch.Size([6]) conv2.weighttorch.Size([16,6,5,5]) conv2.biastorch.Size([16]) fc1.weighttorch.Size([120,400]) fc1.biastorch.Size([120]) fc2.weighttorch.Size([84,120]) fc2.biastorch.Size([84]) fc3.weighttorch.Size([10,84]) fc3.biastorch.Size([10]) Optimizer'sstate_dict: state{} param_groups[{'lr':0.001,'momentum':0.9,'dampening':0,'weight_decay':0,'nesterov':False,'params':[2238501264336,2238501329800,2238501330016,2238501327136,2238501328576,2238501329728,2238501327928,2238501327064,2238501330808,2238501328288]}]
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***objecthasnoattribute'state_dict'
定义了一个类BaseNet并实例化该类:
net=BaseNet()
保存net时报错objecthasnoattribute'state_dict'
torch.save(net.state_dict(),models_dir)
原因是定义类的时候不是继承nn.Module类,比如:
classBaseNet(object): def__init__(self):
把类定义改为
classBaseNet(nn.Module): def__init__(self): super(BaseNet,self).__init__()
以上为个人经验,希望能给大家一个参考,也希望大家多多支持毛票票。如有错误或未考虑完全的地方,望不吝赐教。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。