解决tensorflow模型参数保存和加载的问题
终于找到bug原因!记一下;还是不熟悉平台的原因造成的!
Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误?而直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错?
model.py,里面含有ModelV和ModelP,另外还有modelP.py和modelV.py分别只含有ModelP和ModeV这两个对象,先使用modelP.py和modelV.py分别训练好模型,然后再在model.py里加载进来:
#-*-coding:utf8-*- importtensorflowastf classModelV(): def__init__(self): self.v1=tf.Variable(66,name="v1") self.v2=tf.Variable(77,name="v2") self.save_path="model_v/model.ckpt" self.init=tf.global_variables_initializer() self.saver=tf.train.Saver() self.sess=tf.Session() deftrain(self): self.sess.run(self.init) print'v2',self.v2.eval(self.sess) self.saver.save(self.sess,self.save_path) print"ModelVsaved." defpredict(self): all_vars=tf.trainable_variables() forvinall_vars: print(v.name) self.saver.restore(self.sess,self.save_path) print"ModelVrestored." print'v2',self.v2.eval(self.sess) print'------------------------------------------------------------------' classModelP(): def__init__(self): self.p1=tf.Variable(88,name="p1") self.p2=tf.Variable(99,name="p2") self.save_path="model_p/model.ckpt" self.init=tf.global_variables_initializer() self.saver=tf.train.Saver() self.sess=tf.Session() deftrain(self): self.sess.run(self.init) print'p2',self.p2.eval(self.sess) self.saver.save(self.sess,self.save_path) print"ModelPsaved." defpredict(self): all_vars=tf.trainable_variables() forvinall_vars: printv.name self.saver.restore(self.sess,self.save_path) print"ModelPrestored." print'p2',self.p2.eval(self.sess) print'---------------------------------------------------------------------' if__name__=='__main__': v=ModelV() p=ModelP() v.predict() #v.train() p.predict() #p.train()
这里tf.global_variables_initializer()很关键!尽管你是分别在对象ModelP和ModelV内部分配和定义的tf.Variable(),即v1v2和p1p2,但是对tf这个模块而言,这些都是全局变量,可以通过以下代码查看所有的变量,你就会发现同一个文件中同时运行ModelP和ModelV在初始化之后都打印出了一样的变量,这个是问题的关键所在:
all_vars=tf.trainable_variables() forvinall_vars: print(v.name)
错误。你可以交换modelP和modelV初始化的顺序,看看错误信息的变化
v1:0 v2:0 p1:0 p2:0 ModelVrestored. v277 v1:0 v2:0 p1:0 p2:0 Wtensorflow/core/framework/op_kernel.cc:975]Notfound:Keyv2notfoundincheckpoint Wtensorflow/core/framework/op_kernel.cc:975]Notfound:Keyv1notfoundincheckpoint
实际上,分开运行时,模型保存的参数是正确的,因为在一个模型里的Variable就只有v1v2或者p1p2;但是在一个文件同时运行的时候,模型参数实际上保存的是v1v2p1p2四个,因为在默认情况下,创建的Saver,会直接保存所有的参数。而Saver.restore()又是默认(无Variable参数列表时)按照已经定义好的全局模型变量来加载对应的参数值,在进行ModelV.predict时,按照顺序(从debug可以看出,应该是按照参数顺序一次检测)在模型文件中查找相应的key,此时能够找到对应的v1v2,加载成功,但是在ModelP.predict时,在model_p的模型文件中找不到v1和v2,只有p1和p2,此时就会报错;不过这里的第一次加载还有p1p2找不到没有报错,解释不通,未完待续
Saver.save()和Saver.restore()是一对,分别只保存和加载模型的参数,但是模型的结构怎么知道呢?必须是你定义好了,而且要和保存的模型匹配才能加载;
如果想要在不定义模型的情况下直接加载出模型结构和模型参数值,使用
#加载结构,即模型参数变量等 new_saver=tf.train.import_meta_graph("model_v/model.ckpt.meta") print"ModelVconstruct" all_vars=tf.trainable_variables() forvinall_vars: printv.name #printv.name,v.eval(self.sess)#v都还未初始化,不能求值 #加载模型参数变量的值 new_saver.restore(self.sess,tf.train.latest_checkpoint('model_v/')) print"ModelVrestored." all_vars=tf.trainable_variables() forvinall_vars: printv.name,v.eval(self.sess)
加载结构,即模型参数变量等完成后,就会有变量了,但是不能访问他的值,因为还未赋值,然后再restore一次即可得到值了
那么上述错误的解决方法就是这个改进版本的model.py;其实tf.train.Saver是可以带参数的,他可以保存你想要保存的模型参数,如果不带参数,很可能就会保存tf.trainable_variables()所有的variable,而tf.trainable_variables()又是从tf全局得到的,因此只要在模型保存和加载时,构造对应的带参数的tf.train.Saver即可,这样就会保存和加载正确的模型了
#-*-coding:utf8-*- importtensorflowastf classModelV(): def__init__(self): self.v1=tf.Variable(66,name="v1") self.v2=tf.Variable(77,name="v2") self.save_path="model_v/model.ckpt" self.init=tf.global_variables_initializer() self.sess=tf.Session() deftrain(self): saver=tf.train.Saver([self.v1,self.v2]) self.sess.run(self.init) print'v2',self.v2.eval(self.sess) saver.save(self.sess,self.save_path) print"ModelVsaved." defpredict(self): saver=tf.train.Saver([self.v1,self.v2]) all_vars=tf.trainable_variables() forvinall_vars: printv.name v_vars=[vforvinall_varsifv.name=='v1:0'orv.name=='v2:0'] print"ModelVrestored." saver.restore(self.sess,self.save_path) forvinv_vars: printv.name,v.eval(self.sess) print'v2',self.v2.eval(self.sess) print'------------------------------------------------------------------' classModelP(): def__init__(self): self.p1=tf.Variable(88,name="p1") self.p2=tf.Variable(99,name="p2") self.save_path="model_p/model.ckpt" self.init=tf.global_variables_initializer() self.sess=tf.Session() deftrain(self): saver=tf.train.Saver([self.p1,self.p2]) self.sess.run(self.init) print'p2',self.p2.eval(self.sess) saver.save(self.sess,self.save_path) print"ModelPsaved." defpredict(self): saver=tf.train.Saver([self.p1,self.p2]) all_vars=tf.trainable_variables() p_vars=[vforvinall_varsifv.name=='p1:0'orv.name=='p2:0'] forvinall_vars: printv.name #printv.name,v.eval(self.sess) saver.restore(self.sess,self.save_path) print"ModelPrestored." forpinp_vars: printp.name,p.eval(self.sess) print'p2',self.p2.eval(self.sess) print'----------------------------------------------------------' if__name__=='__main__': v=ModelV() p=ModelP() v.predict() #v.train() p.predict() #p.train()
小结:构造的Saver最好带Variable参数,这样保证保存和加载能够正确执行
以上这篇解决tensorflow模型参数保存和加载的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。