tensorflow 保存模型和取出中间权重例子
下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。
importtensorflowastf importnumpyasnp importos importh5py importpickle fromtensorflow.python.frameworkimportgraph_util fromtensorflow.python.platformimportgfile #设置使用指定GPU os.environ['CUDA_VISIBLE_DEVICES']='1' #下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉 reader=tf.train.NewCheckpointReader('./model.ckpt-100') variables=reader.get_variable_to_shape_map() foreleinvariables: print(ele) print(reader.get_tensor(ele)) x=tf.placeholder(tf.float32,shape=[None,1]) y=4*x+4 w=tf.Variable(tf.random_normal([1],-1,1)) b=tf.Variable(tf.zeros([1])) y_predict=w*x+b loss=tf.reduce_mean(tf.square(y-y_predict)) optimizer=tf.train.GradientDescentOptimizer(0.5) train=optimizer.minimize(loss) isTrain=False#设成True去训练模型 train_steps=100 checkpoint_steps=50 checkpoint_dir='' saver=tf.train.Saver()#defaultstosavingallvariables-inthiscasewandb x_data=np.reshape(np.random.rand(10).astype(np.float32),(10,1)) withtf.Session()assess: sess.run(tf.global_variables_initializer()) ifisTrain: foriinxrange(train_steps): sess.run(train,feed_dict={x:x_data}) if(i+1)%checkpoint_steps==0: saver.save(sess,checkpoint_dir+'model.ckpt',global_step=i+1) else: ckpt=tf.train.get_checkpoint_state(checkpoint_dir) ifckptandckpt.model_checkpoint_path: saver.restore(sess,ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b)) graph_def=tf.get_default_graph().as_graph_def() #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教 output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['Variable']) withtf.gfile.FastGFile('./test.pb','wb')asf: f.write(output_graph_def.SerializeToString()) withtf.Session()assess: #对应最后一部分的写,这里能够将对应的变量取出来 withgfile.FastGFile('./test.pb','rb')asf: graph_def=tf.GraphDef() graph_def.ParseFromString(f.read()) res=tf.import_graph_def(graph_def,return_elements=['Variable:0']) print(sess.run(res)) print(sess.run(graph_def))
以上这篇tensorflow保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。