pytorch 状态字典:state_dict使用详解
pytorch中的state_dict是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)
(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr,momentum,weight_decay等)
备注:
1)state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的PATH="./***.pt"
torch.save(model.state_dict(),PATH)
2)load_state_dict也是model或optimizer之后pytorch自动具备的函数,可以直接调用
model=TheModelClass(*args,**kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
注意:model.eval()的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batchnormalization层"才会进入evalution模态.而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.
模态字典(state_dict)的保存(model是一个网络结构类的对象)
1.1)仅保存学习到的参数,用以下命令
torch.save(model.state_dict(),PATH)
1.2)加载model.state_dict,用以下命令
model=TheModelClass(*args,**kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
备注:model.load_state_dict的操作对象是一个具体的对象,而不能是文件名
2.1)保存整个model的状态,用以下命令
torch.save(model,PATH)
2.2)加载整个model的状态,用以下命令:
#Modelclassmustbedefinedsomewhere model=torch.load(PATH) model.eval()
state_dict是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项
如何仅加载某一层的训练的到的参数(某一层的state)
Ifyouwanttoloadparametersfromonelayertoanother,butsomekeysdonotmatch,simplychangethenameoftheparameterkeysinthestate_dictthatyouareloadingtomatchthekeysinthemodelthatyouareloadinginto.
conv1_weight_state=torch.load('./model_state_dict.pt')['conv1.weight']
加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)
forparaminlist(model.pretrained.parameters()): param.requires_grad=False
注意:requires_grad的操作对象是tensor.
疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:经测试,不可以.model.conv1没有requires_grad属性.
全部测试代码:
#-*-coding:utf-8-*-
importtorch
importtorch.nnasnn
importtorch.nn.functionalasF
importtorch.optimasoptim
#definemodel
classTheModelClass(nn.Module):
def__init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3,6,5)
self.pool=nn.MaxPool2d(2,2)
self.conv2=nn.Conv2d(6,16,5)
self.fc1=nn.Linear(16*5*5,120)
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)
defforward(self,x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
returnx
#initialmodel
model=TheModelClass()
#initializetheoptimizer
optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
#printthemodel'sstate_dict
print("model'sstate_dict:")
forparam_tensorinmodel.state_dict():
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
print("\noptimizer'sstate_dict")
forvar_nameinoptimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])
print("\nprintparticularparam")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
#model_2=TheModelClass()
#model_2.load_state_dict(torch.load('./model_state_dict'))
#model.eval()
#print('\n',model_2.conv1.weight)
#print((model_2.conv1.weight==model.conv1.weight).size())
##仅仅加载某一层的参数
conv1_weight_state=torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
model_2=TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)
以上这篇pytorch状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。