TensorFLow用Saver保存和恢复变量
本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下
建立文件tensor_save.py,保存变量v1,v2的tensor到checkpointfiles中,名称分别设置为v3,v4。
importtensorflowastf
#Createsomevariables.
v1=tf.Variable(3,name="v1")
v2=tf.Variable(4,name="v2")
#Createmodel
y=tf.add(v1,v2)
#Addanoptoinitializethevariables.
init_op=tf.initialize_all_variables()
#Addopstosaveandrestoreallthevariables.
saver=tf.train.Saver({'v3':v1,'v4':v2})
#Later,launchthemodel,initializethevariables,dosomework,savethe
#variablestodisk.
withtf.Session()assess:
sess.run(init_op)
print("v1=",v1.eval())
print("v2=",v2.eval())
#Savethevariablestodisk.
save_path=saver.save(sess,"f:/tmp/model.ckpt")
print("Modelsavedinfile:",save_path)
建立文件tensor_restror.py,将checkpointfiles中名称分别为v3,v4的tensor分别恢复到变量v3,v4中。
importtensorflowastf
#Createsomevariables.
v3=tf.Variable(0,name="v3")
v4=tf.Variable(0,name="v4")
#Createmodel
y=tf.mul(v3,v4)
#Addopstosaveandrestoreallthevariables.
saver=tf.train.Saver()
#Later,launchthemodel,usethesavertorestorevariablesfromdisk,and
#dosomeworkwiththemodel.
withtf.Session()assess:
#Restorevariablesfromdisk.
saver.restore(sess,"f:/tmp/model.ckpt")
print("Modelrestored.")
print("v3=",v3.eval())
print("v4=",v4.eval())
print("y=",sess.run(y))
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。