How to get the output of layers inside a Sequential when finetuning in PyTorch, using VGG as an example
Sometimes when finetuning we find that PyTorch packs many layers into a single Sequential, but we want to pull the results of intermediate layers out for further processing. Here is a method I worked out myself, using VGG as the example; it is admittedly a bit clumsy.
First, the structure of the VGG16 model that ships with PyTorch is as follows:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
  )
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000)
  )
)
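For reference, a printout like the one above can be reproduced with a couple of lines (a minimal sketch; it assumes torchvision is installed, and on recent torchvision versions the pretrained argument has been superseded by weights):

import torchvision

# Load the stock VGG16 with ImageNet weights and print its module structure
vgg16 = torchvision.models.vgg16(pretrained=True)
print(vgg16)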
We want to finetune the features part of VGG16, and I want the outputs of layers 3, 8, 15, 22 and 29 (the ReLU that closes each of the five conv blocks) available for further processing. My idea is to write my own VGG network whose parameters are identical to PyTorch's, but arranged so that the layers we need end up at the boundaries between Sequential blocks rather than buried inside one. The network I wrote is as follows:
import torch.nn as nn


class our_vgg(nn.Module):
    def __init__(self):
        super(our_vgg, self).__init__()
        self.conv1 = nn.Sequential(
            # conv1
            nn.Conv2d(3, 64, 3, padding=35),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            # conv2
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/2
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(
            # conv3
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/4
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.conv4 = nn.Sequential(
            # conv4
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/8
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.conv5 = nn.Sequential(
            # conv5
            nn.MaxPool2d(2, stride=2, ceil_mode=True),  # 1/16
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # conv1..conv5 are the block outputs (layers 3, 8, 15, 22, 29 of the stock model)
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        return conv5
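As written, forward still returns only conv5. Since the whole point is to use the five block outputs, here is a minimal usage sketch (the dummy input and the manual block-by-block calls are my own illustration, not part of the original code); alternatively, simply change forward to return conv1, conv2, conv3, conv4, conv5:

import torch

net = our_vgg()
x = torch.randn(1, 3, 224, 224)  # dummy input, just to exercise the network

# Run each block by hand so every intermediate result stays in scope
conv1 = net.conv1(x)
conv2 = net.conv2(conv1)
conv3 = net.conv3(conv2)
conv4 = net.conv4(conv3)
conv5 = net.conv5(conv4)
print(conv1.shape, conv3.shape, conv5.shape)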
Next comes copying the weights:
import torch


def convert_vgg(vgg16):  # vgg16 is the state_dict of the stock PyTorch model
    net = our_vgg()  # the vgg I wrote
    vgg_items = list(net.state_dict().items())
    vgg16_items = list(vgg16.items())

    pretrain_model = {}
    j = 0
    for k, v in net.state_dict().items():  # fill in the parameters strictly in order
        v = vgg16_items[j][1]
        k = vgg_items[j][0]
        pretrain_model[k] = v
        j += 1
    return pretrain_model


# net is the network we will actually use, i.e. where the weights should end up;
# replace our_vgg() with your own network, whose finetune part must match our_vgg
net = our_vgg()

print('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. drop the entries that do not belong to layers we need
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. merge the pretrained parameters into the existing model_dict
model_dict.update(pretrained_dict)
# 3. load the updated model_dict
net.load_state_dict(model_dict)
print('copy the weight successfully')
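To make sure the copy really happened, a quick sanity check can be added afterwards. The comparison below is my own addition (it assumes your network keeps our_vgg's conv1..conv5 naming); it verifies that the first conv kernel of the new network equals features.0.weight from the stock state_dict, since convert_vgg copies the parameters strictly in order:

src = torch.load('vgg16.pth')  # stock VGG16 weights
dst = net.state_dict()         # weights after the copy

# 'conv1.0.weight' in our_vgg corresponds to 'features.0.weight' in the stock model
assert torch.equal(src['features.0.weight'], dst['conv1.0.weight'])
print('first conv weights match')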
With that, the goal is basically reached. Note that net, i.e. the network we are actually going to use, must have its finetune part laid out exactly like our_vgg.
That is everything in this post on how to get the output of layers inside a Sequential when finetuning in PyTorch, using VGG as an example. I hope it gives you a useful reference, and thank you for supporting 毛票票.