tensorflow saver 保存和恢复指定 tensor的实例讲解
在实践中经常会遇到这样的情况:
1、用简单的模型预训练参数
2、把预训练的参数导入复杂的模型后训练复杂的模型
这时就产生一个问题:
如何加载预训练的参数。
下面就是我的总结。
为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。
卷积层的实现代码如下:
importtensorflowastf #PS:本篇的重担是saver,不过为了方便阅读还是说明下参数 #参数 #name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope #input_data:输入数据 #width,high:卷积小窗口的宽、高 #deep_before,deep_after:卷积前后的神经元数量 #stride:卷积小窗口的移动步长 defmake_conv(name,input_data,width,high,deep_before,deep_after,stride,padding_type='SAME'): globalparameters withtf.name_scope(name)asscope: weights=tf.Variable(tf.truncated_normal([width,high,deep_before,deep_after], dtype=tf.float32,stddev=0.01),trainable=True,name='weights') biases=tf.Variable(tf.constant(0.1,shape=[deep_after]),trainable=True,name='biases') conv=tf.nn.conv2d(input_data,weights,[1,stride,stride,1],padding=padding_type) bias=tf.add(conv,biases) bias=batch_norm(bias,deep_after,1)#batch_norm是自己写的batchnorm函数 conv=tf.maximum(0.1*bias,bias) returnconv
简单的预训练模型就下面一句话
conv1=make_conv('simple-conv1',images,3,3,3,32,1)
复杂的模型是两个卷基层,如下:
conv1=make_conv('complex-conv1',images,3,3,3,32,1) pool1=make_max_pool('layer1-pool1',conv1,2,2) conv2=make_conv('complex-conv2',pool1,3,3,32,64,1)
这时简简单单的在预训练模型中:
saver=tf.train.Saver() withtf.Session()assess: saver.save(sess,'model.ckpt')
就不行了,因为:
1,如果你在预训练模型中使用下面的话打印所有tensor
all_v=tf.global_variables() foriinall_v:printi
会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:
同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。
2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。
解决方法:
1,在预训练模型中定义全局变量
parm_dict={}
并在“returnconv”上面添加下面两行
parm_dict['complex-conv1/weights']=weights parm_dict['complex-conv1/']=biases
然后在定义saver时使用下面这句话:
saver=tf.train.Saver(parm_dict)
这样保存后的模型文件就对应到复杂模型上了。
2,在复杂模型中定义全局变量
parameters=[]
并在“returnconv”上面添加下面行
parameters+=[weights,biases]
然后判断如果是第二个卷积层就不更新parameters。
接着在定义saver时使用下面这句话:
saver=tf.train.Saver(parameters)
这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~complex-conv1/Variable_3统统滚一边去(上面红色部分)。
最后使用下面的代码加载就可以了
withtf.Session()assess: ckpt=tf.train.get_checkpoint_state('.') ifckptandckpt.model_checkpoint_path: saver.restore(sess,ckpt.model_checkpoint_path) else: print'nosaver.' exit()
以上这篇tensorflowsaver保存和恢复指定tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。