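An Effective Way to Convert PyTorch Models to Keras, Explained with FlowNet
Thanks to its dynamic-graph mechanism, PyTorch has been widely adopted and looks set to overtake TensorFlow, but TF still has the edge in engineering deployment. Sometimes a model has to go into production in a real project, and there TF's PB (protobuf) model files and its C++ API are effective tools. This post explains how to convert a PyTorch model to Keras. Since Keras can run on a TensorFlow backend, the converted model can in turn be exported as a PB file. How to obtain and use the PB file will be covered in the next post.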
PytorchToKeras
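First, we need to be clear about one thing: most of the so-called PyTorch-to-Keras (or Keras-to-PyTorch) conversion tools found online and on GitHub either do not run at all or are limited in use (for example, they can only convert certain models). Their code is still instructive, though, because it shows where the two frameworks differ, such as the shape layout of the parameters and the channel ordering (channels_first or channels_last). Once these differences are understood, you can write the conversion code yourself, and writing it yourself is the most reliable approach. As an example, I will convert NVIDIA's open-source PyTorch implementation of FlowNet into a Keras model. The whole process has two parts:
1. Following the structure of the PyTorch model, write the corresponding Keras code. The Keras functional API makes this very convenient.
2. Assign the PyTorch model's parameters to the Keras model, layer by layer, according to the layer names.
These two steps look simple, but I took quite a few detours in practice. The key question is whether the parameter shapes are laid out the same way in the two frameworks, and they are not. FlowNet serves as the example below.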
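The FlowNet code in PyTorch
Only the layer names and layer parameters are shown here rather than the whole model definition, which would take up a lot of space and only pad out the post.
First, the FlowNet model built with Keras, printed directly with model.summary():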
Layer (type)                     Output Shape          Param #   Connected to
==================================================================================================
input_1 (InputLayer)             (None, 6, 512, 512)   0
conv0 (Conv2D)                   (None, 64, 512, 512)  3520      input_1[0][0]
leaky_re_lu_1 (LeakyReLU)        (None, 64, 512, 512)  0         conv0[0][0]
zero_padding2d_1 (ZeroPadding2D  (None, 64, 514, 514)  0         leaky_re_lu_1[0][0]
conv1 (Conv2D)                   (None, 64, 256, 256)  36928     zero_padding2d_1[0][0]
leaky_re_lu_2 (LeakyReLU)        (None, 64, 256, 256)  0         conv1[0][0]
conv1_1 (Conv2D)                 (None, 128, 256, 256  73856     leaky_re_lu_2[0][0]
leaky_re_lu_3 (LeakyReLU)        (None, 128, 256, 256  0         conv1_1[0][0]
zero_padding2d_2 (ZeroPadding2D  (None, 128, 258, 258  0         leaky_re_lu_3[0][0]
conv2 (Conv2D)                   (None, 128, 128, 128  147584    zero_padding2d_2[0][0]
leaky_re_lu_4 (LeakyReLU)        (None, 128, 128, 128  0         conv2[0][0]
conv2_1 (Conv2D)                 (None, 128, 128, 128  147584    leaky_re_lu_4[0][0]
leaky_re_lu_5 (LeakyReLU)        (None, 128, 128, 128  0         conv2_1[0][0]
zero_padding2d_3 (ZeroPadding2D  (None, 128, 130, 130  0         leaky_re_lu_5[0][0]
conv3 (Conv2D)                   (None, 256, 64, 64)   295168    zero_padding2d_3[0][0]
leaky_re_lu_6 (LeakyReLU)        (None, 256, 64, 64)   0         conv3[0][0]
conv3_1 (Conv2D)                 (None, 256, 64, 64)   590080    leaky_re_lu_6[0][0]
leaky_re_lu_7 (LeakyReLU)        (None, 256, 64, 64)   0         conv3_1[0][0]
zero_padding2d_4 (ZeroPadding2D  (None, 256, 66, 66)   0         leaky_re_lu_7[0][0]
conv4 (Conv2D)                   (None, 512, 32, 32)   1180160   zero_padding2d_4[0][0]
leaky_re_lu_8 (LeakyReLU)        (None, 512, 32, 32)   0         conv4[0][0]
conv4_1 (Conv2D)                 (None, 512, 32, 32)   2359808   leaky_re_lu_8[0][0]
leaky_re_lu_9 (LeakyReLU)        (None, 512, 32, 32)   0         conv4_1[0][0]
zero_padding2d_5 (ZeroPadding2D  (None, 512, 34, 34)   0         leaky_re_lu_9[0][0]
conv5 (Conv2D)                   (None, 512, 16, 16)   2359808   zero_padding2d_5[0][0]
leaky_re_lu_10 (LeakyReLU)       (None, 512, 16, 16)   0         conv5[0][0]
conv5_1 (Conv2D)                 (None, 512, 16, 16)   2359808   leaky_re_lu_10[0][0]
leaky_re_lu_11 (LeakyReLU)       (None, 512, 16, 16)   0         conv5_1[0][0]
zero_padding2d_6 (ZeroPadding2D  (None, 512, 18, 18)   0         leaky_re_lu_11[0][0]
conv6 (Conv2D)                   (None, 1024, 8, 8)    4719616   zero_padding2d_6[0][0]
leaky_re_lu_12 (LeakyReLU)       (None, 1024, 8, 8)    0         conv6[0][0]
conv6_1 (Conv2D)                 (None, 1024, 8, 8)    9438208   leaky_re_lu_12[0][0]
leaky_re_lu_13 (LeakyReLU)       (None, 1024, 8, 8)    0         conv6_1[0][0]
deconv5 (Conv2DTranspose)        (None, 512, 16, 16)   8389120   leaky_re_lu_13[0][0]
predict_flow6 (Conv2D)           (None, 2, 8, 8)       18434     leaky_re_lu_13[0][0]
leaky_re_lu_14 (LeakyReLU)       (None, 512, 16, 16)   0         deconv5[0][0]
upsampled_flow6_to_5 (Conv2DTra  (None, 2, 16, 16)     66        predict_flow6[0][0]
concatenate_1 (Concatenate)      (None, 1026, 16, 16)  0         leaky_re_lu_11[0][0]
                                                                 leaky_re_lu_14[0][0]
                                                                 upsampled_flow6_to_5[0][0]
inter_conv5 (Conv2D)             (None, 512, 16, 16)   4728320   concatenate_1[0][0]
deconv4 (Conv2DTranspose)        (None, 256, 32, 32)   4202752   concatenate_1[0][0]
predict_flow5 (Conv2D)           (None, 2, 16, 16)     9218      inter_conv5[0][0]
leaky_re_lu_15 (LeakyReLU)       (None, 256, 32, 32)   0         deconv4[0][0]
upsampled_flow5_to4 (Conv2DTran  (None, 2, 32, 32)     66        predict_flow5[0][0]
concatenate_2 (Concatenate)      (None, 770, 32, 32)   0         leaky_re_lu_9[0][0]
                                                                 leaky_re_lu_15[0][0]
                                                                 upsampled_flow5_to4[0][0]
inter_conv4 (Conv2D)             (None, 256, 32, 32)   1774336   concatenate_2[0][0]
deconv3 (Conv2DTranspose)        (None, 128, 64, 64)   1577088   concatenate_2[0][0]
predict_flow4 (Conv2D)           (None, 2, 32, 32)     4610      inter_conv4[0][0]
leaky_re_lu_16 (LeakyReLU)       (None, 128, 64, 64)   0         deconv3[0][0]
upsampled_flow4_to3 (Conv2DTran  (None, 2, 64, 64)     66        predict_flow4[0][0]
concatenate_3 (Concatenate)      (None, 386, 64, 64)   0         leaky_re_lu_7[0][0]
                                                                 leaky_re_lu_16[0][0]
                                                                 upsampled_flow4_to3[0][0]
inter_conv3 (Conv2D)             (None, 128, 64, 64)   444800    concatenate_3[0][0]
deconv2 (Conv2DTranspose)        (None, 64, 128, 128)  395328    concatenate_3[0][0]
predict_flow3 (Conv2D)           (None, 2, 64, 64)     2306      inter_conv3[0][0]
leaky_re_lu_17 (LeakyReLU)       (None, 64, 128, 128)  0         deconv2[0][0]
upsampled_flow3_to2 (Conv2DTran  (None, 2, 128, 128)   66        predict_flow3[0][0]
concatenate_4 (Concatenate)      (None, 194, 128, 128  0         leaky_re_lu_5[0][0]
                                                                 leaky_re_lu_17[0][0]
                                                                 upsampled_flow3_to2[0][0]
inter_conv2 (Conv2D)             (None, 64, 128, 128)  111808    concatenate_4[0][0]
predict_flow2 (Conv2D)           (None, 2, 128, 128)   1154      inter_conv2[0][0]
up_sampling2d_1 (UpSampling2D)   (None, 2, 512, 512)   0         predict_flow2[0][0]
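In the summary, a zero_padding2d_* layer sits in front of every stride-2 convolution. A plausible reading, and it is only an assumption because the summary does not show the padding mode, is that an explicit padding of 1 followed by an unpadded convolution reproduces PyTorch's symmetric padding=(1, 1). The sketch below only checks the output-size arithmetic with the standard formulas; the helper names are made up for illustration.

```python
def conv_out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution with symmetric zero padding."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out_size(size, kernel, stride, padding):
    """Spatial output size of a transposed convolution (no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# conv1 in PyTorch: kernel 3, stride 2, padding 1 on a 512-pixel side.
pytorch_conv1 = conv_out_size(512, kernel=3, stride=2, padding=1)

# conv1 in Keras: ZeroPadding2D gives 514, then (assumed) an unpadded
# 3x3 convolution with stride 2.
keras_conv1 = conv_out_size(514, kernel=3, stride=2, padding=0)

# deconv5: kernel 4, stride 2, padding 1 on the 8-pixel conv6_1 output.
deconv5 = deconv_out_size(8, kernel=4, stride=2, padding=1)

print(pytorch_conv1, keras_conv1, deconv5)
```

All three values agree with the Output Shape column: 256 for conv1 in both frameworks and 16 for deconv5.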
Now the FlowNet model built with PyTorch:
  (conv0): Sequential(
    (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv1_1): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv2): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv2_1): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv3_1): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv4_1): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv5): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv5_1): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv6): Sequential(
    (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (conv6_1): Sequential(
    (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (deconv5): Sequential(
    (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (deconv4): Sequential(
    (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (deconv3): Sequential(
    (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (deconv2): Sequential(
    (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace)
  )
  (inter_conv5): Sequential(
    (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (inter_conv4): Sequential(
    (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (inter_conv3): Sequential(
    (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (inter_conv2): Sequential(
    (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (predict_flow6): Conv2d(1024, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (predict_flow5): Conv2d(512, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (predict_flow4): Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (predict_flow3): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (predict_flow2): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsampled_flow6_to_5): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (upsampled_flow5_to_4): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (upsampled_flow4_to_3): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (upsampled_flow3_to_2): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (upsample1): Upsample(scale_factor=4.0, mode=bilinear)
)
conv0 Sequential(
  (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv0.0 Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv0.1 LeakyReLU(negative_slope=0.1, inplace)
conv1 Sequential(
  (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1.0 Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv1.1 LeakyReLU(negative_slope=0.1, inplace)
conv1_1 Sequential(
  (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1_1.0 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv1_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv2 Sequential(
  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv2.1 LeakyReLU(negative_slope=0.1, inplace)
conv2_1 Sequential(
  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2_1.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv2_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv3 Sequential(
  (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3.0 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv3.1 LeakyReLU(negative_slope=0.1, inplace)
conv3_1 Sequential(
  (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3_1.0 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv3_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv4 Sequential(
  (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4.0 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv4.1 LeakyReLU(negative_slope=0.1, inplace)
conv4_1 Sequential(
  (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv4_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv5 Sequential(
  (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv5.1 LeakyReLU(negative_slope=0.1, inplace)
conv5_1 Sequential(
  (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv5_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv6 Sequential(
  (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6.0 Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv6.1 LeakyReLU(negative_slope=0.1, inplace)
conv6_1 Sequential(
  (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6_1.0 Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv6_1.1 LeakyReLU(negative_slope=0.1, inplace)
deconv5 Sequential(
  (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv5.0 ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv5.1 LeakyReLU(negative_slope=0.1, inplace)
deconv4 Sequential(
  (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv4.0 ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv4.1 LeakyReLU(negative_slope=0.1, inplace)
deconv3 Sequential(
  (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv3.0 ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv3.1 LeakyReLU(negative_slope=0.1, inplace)
deconv2 Sequential(
  (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv2.0 ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv2.1 LeakyReLU(negative_slope=0.1, inplace)
inter_conv5 Sequential(
  (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv5.0 Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv4 Sequential(
  (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv4.0 Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv3 Sequential(
  (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv3.0 Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv2 Sequential(
  (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
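The output of named_modules() on the PyTorch model does not follow the execution order: modules are listed in the order in which they were registered in the constructor, and with a dynamic graph the path actually taken is only known once data flows through the network. That is why the order above looks jumbled. The point, however, is that the Keras model shown earlier was indeed built to mirror the officially released PyTorch model.
With the model built, we come to the key step: assigning weights to the Keras model.
Assigning weights to the Keras model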
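One quick way to see that the two models match is to recompute the Param # column of the Keras summary from the channel counts and kernel sizes in the PyTorch printout. This is a minimal sketch using the usual formula for a convolution with bias; the function name and the selection of layers are just for illustration.

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights plus biases of a 2-D (transposed) convolution, square kernel."""
    return kernel * kernel * in_ch * out_ch + out_ch


# (in_channels, out_channels, kernel_size) from the PyTorch printout,
# paired with the Param # value from the Keras summary.
checks = {
    "conv0": ((6, 64, 3), 3520),
    "conv1": ((64, 64, 3), 36928),
    "conv6_1": ((1024, 1024, 3), 9438208),
    "deconv5": ((1024, 512, 4), 8389120),
    "inter_conv5": ((1026, 512, 3), 4728320),
    "upsampled_flow6_to_5": ((2, 2, 4), 66),
}

mismatches = []
for name, (spec, expected) in checks.items():
    got = conv_params(*spec)
    print(name, got, "ok" if got == expected else "MISMATCH")
    if got != expected:
        mismatches.append(name)
```

Every layer in this selection reproduces the Keras parameter count, so the layer definitions agree at least in size.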
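This step comes down to three points:
1. PyTorch is channels_first, while Keras defaults to channels_last, so add these lines at the top of the code:
from keras import backend as K
K.set_image_data_format('channels_first')
K.set_learning_phase(0)  # inference mode
2. A convolution layer's weight is a 4-D tensor. Is its layout the same in PyTorch and Keras? It is not, which is why this point is on the list, so the PyTorch weights have to be rearranged.
3. Since the convolution weight layout differs between the two frameworks, the transposed-convolution layout differs as well.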
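With channels_first set as above, inputs can keep PyTorch's (batch, channels, height, width) layout. For comparison only, and not something this article does, the sketch below shows what moving a batch between the two data layouts looks like in NumPy; the array is dummy data.

```python
import numpy as np

# A dummy batch in PyTorch's layout: (batch, channels, height, width).
x_nchw = np.arange(2 * 6 * 4 * 4, dtype=np.float32).reshape(2, 6, 4, 4)

# Move the channel axis to the end, as a channels_last model would expect.
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

# And back again, e.g. to compare an output with the PyTorch result.
x_back = np.transpose(x_nhwc, (0, 3, 1, 2))

print(x_nchw.shape, x_nhwc.shape, x_back.shape)
```

Setting channels_first avoids these transposes on the data, but as the following sections show, the weight tensors still need their own reordering.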
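Let's first look at how convolution weights are laid out in the two frameworks.
Convolution weight layout in Keras
The following code prints the weight shapes of every Keras layer:
for l in model.layers:
    print(l.name)
    for i, w in enumerate(l.get_weights()):
        print('%d' % i, w.shape)
The output for the first convolution layer is shown below. Index 0 is the shape of the convolution kernel and index 1 is the bias.
conv0
0 (3, 3, 6, 64)
1 (64,)
So the Keras convolution weight layout is [height, width, input_channels, out_channels].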
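Convolution weight layout in PyTorch
The following code prints the shapes of all parameters. Again we pick out the first convolution layer:
net = FlowNet2SD()
for n, m in net.named_parameters():
    print(n)
    print(m.data.size())
conv0.0.weight
torch.Size([64, 6, 3, 3])
conv0.0.bias
torch.Size([64])
Comparing the two, the PyTorch convolution weight has the layout [out_channels, input_channels, height, width].
So after taking the weights out of PyTorch, we need np.transpose to reorder their axes before they can be passed to the corresponding layer of the Keras model.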
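The reordering can be tried out on a random NumPy array standing in for conv0.0.weight. This is a sketch with dummy data, not the real FlowNet weights.

```python
import numpy as np

out_ch, in_ch, kh, kw = 64, 6, 3, 3

# Stand-in for conv0.0.weight in PyTorch layout [out, in, height, width].
w_torch = np.random.rand(out_ch, in_ch, kh, kw).astype(np.float32)

# New axis order: height (2), width (3), in (1), out (0).
w_keras = np.transpose(w_torch, (2, 3, 1, 0))

print(w_torch.shape, "->", w_keras.shape)

# Each element keeps its meaning: w_keras[h, w, i, o] is w_torch[o, i, h, w].
h, w, i, o = 1, 2, 4, 10
same = bool(w_keras[h, w, i, o] == w_torch[o, i, h, w])
print(same)
```

The resulting shape (3, 3, 6, 64) is exactly what get_weights() reported for the Keras conv0 layer.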
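Transposed-convolution weight layout in Keras
The code is the same as before. Find the entry for a transposed convolution in its output:
deconv4
0 (4, 4, 256, 1026)
1 (256,)
So in Keras the transposed-convolution weight layout is [height, width, out_channels, input_channels].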
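Transposed-convolution weight layout in PyTorch
Again the code is the same as before. Find the entry for the same transposed convolution:
deconv4.0.weight
torch.Size([1026, 256, 4, 4])
deconv4.0.bias
torch.Size([256])
So in PyTorch the transposed-convolution weight layout is [input_channels, out_channels, height, width].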
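Summary
For convolution layers, the PyTorch weights have to be passed through
np.transpose(weight.data.numpy(), [2, 3, 1, 0])
before they can be assigned to the weights of the corresponding Keras layer.
For transposed convolutions, the comparison shows that exactly the same permutation applies: it turns PyTorch's [input_channels, out_channels, height, width] into Keras's [height, width, out_channels, input_channels]. Try it if you don't believe it. O(∩_∩)O Haha~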
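To take up that invitation, here is a small check on an array with the shape of deconv4.0.weight. The helper name is hypothetical and the array is dummy data with a single marked element.

```python
import numpy as np


def torch_kernel_to_keras(weight):
    """Reorder a 4-D PyTorch kernel with the (2, 3, 1, 0) permutation."""
    return np.transpose(weight, (2, 3, 1, 0))


# Stand-in for deconv4.0.weight: [input_channels, out_channels, height, width].
w_torch = np.zeros((1026, 256, 4, 4), dtype=np.float32)
w_torch[7, 3, 1, 2] = 1.0  # mark one element to follow it through

w_keras = torch_kernel_to_keras(w_torch)
print(w_keras.shape)

# The marked element lands at [height, width, out_channels, input_channels].
marked = float(w_keras[1, 2, 3, 7])
print(marked)
```

The shape comes out as (4, 4, 256, 1026), which matches the Keras deconv4 layer.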
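The bias is a 1-D vector in both frameworks and needs no rearranging.
In some cases the transposed kernel may also need to be flipped along its first two axes (height and width), but this is rarely necessary:
weights[::-1, ::-1, :, :]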
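What this slicing does is easiest to see on a tiny kernel. The sketch below assumes the array is already in Keras layout [height, width, in, out] and uses made-up values.

```python
import numpy as np

# A 3x3 kernel with a single input and output channel, values 0..8.
k = np.arange(9, dtype=np.float32).reshape(3, 3, 1, 1)

# Reverse the height and width axes; the channel axes stay untouched.
flipped = k[::-1, ::-1, :, :]

print(k[:, :, 0, 0])
print(flipped[:, :, 0, 0])
```

The flip is a 180-degree rotation of each 2-D kernel slice, so the value in the top-left corner moves to the bottom-right.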
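Assignment
With the preprocessing settled, we move on to the second step and actually assign the weights.
First, the preprocessing code:
for k, v in weights_from_torch.items():
    if 'bias' not in k:
        # kernels: reorder the axes to the Keras layout
        weights_from_torch[k] = v.data.numpy().transpose(2, 3, 1, 0)
    else:
        # biases are 1-D, so only convert them to NumPy
        weights_from_torch[k] = v.data.numpy()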
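A slightly different way to write the same preprocessing is to decide by the number of dimensions instead of by the parameter name, and to build a new dictionary rather than overwriting the old one. This is only a sketch: the function name is made up, and the demo uses NumPy arrays in place of real tensors, which would first have to be converted to NumPy.

```python
import numpy as np


def to_keras_layout(state):
    """Return a new dict with 4-D kernels transposed and other arrays as-is."""
    converted = {}
    for name, value in state.items():
        array = np.asarray(value)
        if array.ndim == 4:
            # conv and transposed-conv kernels share the same permutation
            array = array.transpose(2, 3, 1, 0)
        converted[name] = array
    return converted


# Stand-in for the PyTorch weights, with NumPy arrays instead of tensors.
fake_state = {
    "conv0.0.weight": np.zeros((64, 6, 3, 3), dtype=np.float32),
    "conv0.0.bias": np.zeros(64, dtype=np.float32),
}

converted = to_keras_layout(fake_state)
shapes = {name: array.shape for name, array in converted.items()}
print(shapes)
```

Checking the dimensionality works here because every 4-D parameter in this model is a convolution or transposed-convolution kernel.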
Only part of the assignment code is shown here for reference:
k_model = k_model()
for layer in k_model.layers:
    current_layer_name = layer.name
    if current_layer_name == 'conv0':
        weights = [weights_from_torch['conv0.0.weight'], weights_from_torch['conv0.0.bias']]
        layer.set_weights(weights)
    elif current_layer_name == 'conv1':
        weights = [weights_from_torch['conv1.0.weight'], weights_from_torch['conv1.0.bias']]
        layer.set_weights(weights)
    elif current_layer_name == 'conv1_1':
        weights = [weights_from_torch['conv1_1.0.weight'], weights_from_torch['conv1_1.0.bias']]
        layer.set_weights(weights)
1. First define the Keras model and get the list of all its layers through layers.
2. Iterate over the layers and assign each one its corresponding values.
3. The assignment uses set_weights, whose argument must be a list in the same form as the return value of get_weights, i.e. [conv_weights, bias_weights].
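The if/elif chain can also be replaced by a lookup that derives the PyTorch parameter names from the Keras layer name. The sketch below is one possible way to do that, not the code from the project: the function names are hypothetical, and a small FakeLayer class stands in for real Keras layers so the logic can be run on its own.

```python
def find_torch_keys(layer_name, torch_weights):
    """Return (weight_key, bias_key) for a Keras layer name, or None."""
    # Layers wrapped in Sequential are stored as '<name>.0.*',
    # bare layers such as predict_flow6 as '<name>.*'.
    for prefix in (layer_name + ".0", layer_name):
        weight_key, bias_key = prefix + ".weight", prefix + ".bias"
        if weight_key in torch_weights and bias_key in torch_weights:
            return weight_key, bias_key
    return None


def assign_all(layers, torch_weights):
    """Call set_weights on every layer that has matching PyTorch parameters."""
    assigned = []
    for layer in layers:
        keys = find_torch_keys(layer.name, torch_weights)
        if keys is None:
            continue  # activation, padding and concatenate layers have no weights
        layer.set_weights([torch_weights[keys[0]], torch_weights[keys[1]]])
        assigned.append(layer.name)
    return assigned


class FakeLayer:
    """Hypothetical stand-in for a Keras layer, used only in this demo."""

    def __init__(self, name):
        self.name = name
        self.weights = None

    def set_weights(self, weights):
        self.weights = weights


torch_weights = {
    "conv0.0.weight": "W0", "conv0.0.bias": "b0",
    "predict_flow6.weight": "W6", "predict_flow6.bias": "b6",
}
layers = [FakeLayer("conv0"), FakeLayer("leaky_re_lu_1"), FakeLayer("predict_flow6")]
assigned = assign_all(layers, torch_weights)
print(assigned)
```

A lookup like this only works where the names agree. In the printouts above the Keras layers are called upsampled_flow5_to4, upsampled_flow4_to3 and upsampled_flow3_to2, while PyTorch uses upsampled_flow5_to_4 and so on, so those layers would still need an explicit mapping.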
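Finally, I hope everyone manages to migrate their own models. The project is open-sourced on my personal GitHub, with detailed usage instructions and sample data, so it can be run directly.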