Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解
一、保存:
graph_util.convert_variables_to_constants可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量
参数2:GraphDef对象,它描述了计算网络
参数3:Graph图中需要输出的节点的名称的列表
返回值:精简版的GraphDef对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:
constant_graph=graph_util.convert_variables_to_constants(sess,sess.graph_def,['sum_operation']) withopen(pbName,mode='wb')asf: f.write(constant_graph.SerializeToString())
需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。
二、恢复:
恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:
graph0=tf.GraphDef() withopen(pbName,mode='rb')asf: graph0.ParseFromString(f.read()) tf.import_graph_def(graph0,name='')
三、代码:
importtensorflowastf fromtensorflow.python.frameworkimportgraph_util pbName='graphA.pb' defgraphCreate(): withtf.Session()assess: var1=tf.placeholder(tf.int32,name='var1') var2=tf.Variable(20,name='var2')#实参name='var2'指定了操作名,该操作返回的张量名是在 #'var2'后面:0,即var2:0是返回的张量名,也就是说变量 #var2的名称是'var2:0' var3=tf.Variable(30,name='var3') var4=tf.Variable(40,name='var4') var4op=tf.assign(var4,1000,name='var4op1') sum=tf.Variable(4,name='sum') sum=tf.add(var1,var2,name='var1_var2') sum=tf.add(sum,var3,name='sum_var3') sumOps=tf.add(sum,var4,name='sum_operation') oper=tf.get_default_graph().get_operations() withopen('operation.csv','wt')asf: s='name,type,output\n' f.write(s) foroinoper: s=o.name s+=','+o.type inp=o.inputs oup=o.outputs foriipininp: s#s+=','+str(iip) foriopinoup: s+=','+str(iop) s+='\n' f.write(s) forvarintf.global_variables(): print('variable=>',var.name)#张量是tf.Variable/tf.Add之类操作的结果, #张量的名字使用操作名加:0来表示 init=tf.global_variables_initializer() sess.run(init) sess.run(var4op) print('sum_operationresultisTensor',sess.run(sumOps,feed_dict={var1:1})) constant_graph=graph_util.convert_variables_to_constants(sess,sess.graph_def,['sum_operation']) withopen(pbName,mode='wb')asf: f.write(constant_graph.SerializeToString()) defgraphGet(): print("startget:") withtf.Graph().as_default(): graph0=tf.GraphDef() withopen(pbName,mode='rb')asf: graph0.ParseFromString(f.read()) tf.import_graph_def(graph0,name='') withtf.Session()assess: init=tf.global_variables_initializer() sess.run(init) v1=sess.graph.get_tensor_by_name('var1:0') v2=sess.graph.get_tensor_by_name('var2:0') v3=sess.graph.get_tensor_by_name('var3:0') v4=sess.graph.get_tensor_by_name('var4:0') sumTensor=sess.graph.get_tensor_by_name("sum_operation:0") print('sumTensoris:',sumTensor) print(sess.run(sumTensor,feed_dict={v1:1})) graphCreate() graphGet()
四、保存pb函数代码里的操作名称/类型/返回的张量:
以上这篇Tensorflow使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。