tensorflow 保存模型和取出中间权重例子
下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。
importtensorflowastf
importnumpyasnp
importos
importh5py
importpickle
fromtensorflow.python.frameworkimportgraph_util
fromtensorflow.python.platformimportgfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES']='1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader=tf.train.NewCheckpointReader('./model.ckpt-100')
variables=reader.get_variable_to_shape_map()
foreleinvariables:
print(ele)
print(reader.get_tensor(ele))
x=tf.placeholder(tf.float32,shape=[None,1])
y=4*x+4
w=tf.Variable(tf.random_normal([1],-1,1))
b=tf.Variable(tf.zeros([1]))
y_predict=w*x+b
loss=tf.reduce_mean(tf.square(y-y_predict))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
isTrain=False#设成True去训练模型
train_steps=100
checkpoint_steps=50
checkpoint_dir=''
saver=tf.train.Saver()#defaultstosavingallvariables-inthiscasewandb
x_data=np.reshape(np.random.rand(10).astype(np.float32),(10,1))
withtf.Session()assess:
sess.run(tf.global_variables_initializer())
ifisTrain:
foriinxrange(train_steps):
sess.run(train,feed_dict={x:x_data})
if(i+1)%checkpoint_steps==0:
saver.save(sess,checkpoint_dir+'model.ckpt',global_step=i+1)
else:
ckpt=tf.train.get_checkpoint_state(checkpoint_dir)
ifckptandckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
graph_def=tf.get_default_graph().as_graph_def()
#通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['Variable'])
withtf.gfile.FastGFile('./test.pb','wb')asf:
f.write(output_graph_def.SerializeToString())
withtf.Session()assess:
#对应最后一部分的写,这里能够将对应的变量取出来
withgfile.FastGFile('./test.pb','rb')asf:
graph_def=tf.GraphDef()
graph_def.ParseFromString(f.read())
res=tf.import_graph_def(graph_def,return_elements=['Variable:0'])
print(sess.run(res))
print(sess.run(graph_def))
以上这篇tensorflow保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。