keras模型保存为tensorflow的二进制模型方式
最近需要将使用keras训练的模型移植到手机上使用,因此需要转换到tensorflow的二进制模型。
折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:
# coding=utf-8
"""Convert a Keras .h5 model into a frozen, binary TensorFlow GraphDef (.pb)."""
import os
import os.path as osp
import sys

import tensorflow as tf
from keras import backend as K
from keras.models import load_model
from tensorflow.python.framework import graph_io


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants holding their current value in the session. The new graph is
    pruned so that subgraphs not necessary to compute the requested outputs
    are removed.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The caller's list
                        is not modified.
    @param clear_devices Remove the device directives from the graph for
                         better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        var_op_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(var_op_names).difference(keep_var_names or []))
        # Build a fresh list. The previous `output_names += ...` extended the
        # caller's list in place, so the caller's argument grew on every call.
        # Variable ops are added as outputs so that they are kept (as constants).
        output_names = list(output_names or []) + var_op_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
    return frozen_graph


input_fld = sys.path[0]  # directory of the running script
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'

# osp.join instead of string concatenation: with an empty sys.path[0]
# (interactive use), '' + '/tensorflow_model/' pointed at the filesystem root.
output_fld = osp.join(input_fld, 'tensorflow_model')
if not osp.isdir(output_fld):
    os.makedirs(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

# Learning phase 0 = test/inference mode; must be set before the model is built.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)
print('input is: {}'.format(net_model.input.name))
print('output is: {}'.format(net_model.output.name))

sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: {}'.format(
    osp.join(output_fld, output_graph_name)))
上面的代码会把模型保存到脚本所在目录(sys.path[0])下的 tensorflow_model 目录中,文件名为 tensor_model.pb。
验证(加载导出的 .pb 模型,对一张图片进行预测):
import cv2
import numpy as np
import PIL.Image as Image  # not referenced below; kept from the original snippet
import tensorflow as tf


def recognize(jpg_path, pb_file_path,
              input_tensor_name="convolution2d_1_input:0",
              output_tensor_name="activation_4/Softmax:0"):
    """Run a frozen .pb classifier on one image and print the prediction.

    :param jpg_path: path of the image file; read with cv2 in grayscale mode.
    :param pb_file_path: path of a binary frozen GraphDef (.pb).
    :param input_tensor_name: name of the graph's input tensor. Use the
        "input is:" value printed by the export script.
    :param output_tensor_name: name of the graph's softmax output tensor. Use
        the "output is:" value printed by the export script.
    :return: array of predicted class indices (argmax over axis 1).
    :raises IOError: if the image cannot be read.
    """
    # cv2.imread returns None instead of raising for a missing/unreadable
    # file; fail early with a clear message rather than in reshape().
    img = cv2.imread(jpg_path, 0)
    if img is None:
        raise IOError('cannot read image: {}'.format(jpg_path))

    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # name="" keeps the original tensor names (no import prefix).
        tensors = tf.import_graph_def(output_graph_def, name="")
        print(tensors)

        with tf.Session() as sess:
            # No global_variables_initializer() needed: in a frozen graph all
            # variables have already been converted to constants.
            for op in sess.graph.get_operations():
                print(op.values())

            input_x = sess.graph.get_tensor_by_name(input_tensor_name)
            print(input_x)
            out_softmax = sess.graph.get_tensor_by_name(output_tensor_name)
            print(out_softmax)

            # Reshape to (batch, 28, 28, 1), scale to [0, 1] and invert.
            # Assumes a 28x28 image and a model trained on inverted
            # intensities — TODO confirm against the training preprocessing.
            batch = 1.0 - np.array(img).reshape((-1, 28, 28, 1)) / 255.0
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x: batch})
            print("img_out_softmax: {}".format(img_out_softmax))
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("label: {}".format(prediction_labels))
            return prediction_labels


if __name__ == "__main__":
    # Matches the directory and file name written by the export script.
    pb_path = 'tensorflow_model/tensor_model.pb'
    img = 'test/6/8_48.jpg'
    recognize(img, pb_path)
补充知识:如何将 Keras 训练好的模型转换成 TensorFlow 的 .pb 文件,并在 TensorFlow Serving 环境中调用
首先,Keras 训练好的模型通过自带的 model.save() 保存为 HDF5 文件(本文中扩展名为 .model,常见的是 .h5)。
模型可以通过 my_model = keras.models.load_model(filepath) 载入。
要将该模型转换为 .pb 格式的 TensorFlow 模型,代码如下:
# -*- coding: utf-8 -*-
"""Freeze the Keras biLSTM brand-recognition model into a binary .pb GraphDef."""
import collections
from collections import defaultdict
import os
import os.path as osp
import sys

if sys.version_info[0] == 2:
    # Python 2 only: make implicit str/unicode conversions use UTF-8.
    # reload() is not a builtin and setdefaultencoding does not exist on
    # Python 3, so the unguarded calls raised NameError there.
    reload(sys)  # noqa: F821
    sys.setdefaultencoding('utf-8')

import jieba
import numpy as np
import tensorflow as tf
from keras import backend as K
# The layer/preprocessing/sklearn imports below are not referenced by this
# conversion script; they are kept from the original snippet.
from keras.layers import Dropout
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential, load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
from tensorflow.python.framework import graph_io


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the state of a session into a pruned graph of constants.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names Variable names that should not be frozen, or None
                          to freeze all variables in the graph.
    @param output_names Names of the relevant graph outputs. The caller's list
                        is not modified.
    @param clear_devices Remove device directives for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        var_op_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(var_op_names).difference(keep_var_names or []))
        # New list instead of `+=`, which mutated the caller's list in place.
        output_names = list(output_names or []) + var_op_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
    return frozen_graph


input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

# osp.join avoids the '//' produced by concatenating onto a path ending in '/'.
output_fld = osp.join(input_fld, 'tensorflow_model')
if not osp.isdir(output_fld):
    os.makedirs(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

# Learning phase 0 = test/inference mode; must be set before loading.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)
print('input is: {}'.format(net_model.input.name))
print('output is: {}'.format(net_model.output.name))

sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])
# Write a binary GraphDef: the SavedModel conversion step reads this file
# with ParseFromString, which cannot parse the text format (as_text=True).
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: {}'.format(
    osp.join(output_fld, output_graph_name)))
运行后,模型就保存成了 .pb 格式的文件。
问题来了:这样保存下来的 .pb 文件是冻结图(frozen model),其中的变量都已被转换为常量。
如果直接用 TensorFlow Serving 加载这个模型,会报错:
E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`
这是因为 TensorFlow Serving 要求加载带有 serve 标签的 SavedModel 格式,而不是冻结后的 GraphDef。
因此需要把 frozen model 转换为 SavedModel 格式。解决方案如下(接着上一段脚本运行,其中用到了上一段中的 tf 和 net_model):
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

# Continues the export script above: `net_model` is the Keras model loaded
# there. It is only used here to look up the input/output tensor names.
input_tensor_name = net_model.input.name
output_tensor_name = net_model.output.name

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
# Must match the file written by the export script (tensor_model_v3.pb). It
# must be a binary GraphDef (written with as_text=False), because it is
# parsed with ParseFromString below.
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model_v3.pb'

# NOTE(review): in TF 1.x SavedModelBuilder refuses an export_dir that already
# exists; remove old exports first — verify for your TF version.
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing,
    # so the tensor names from the Keras model still resolve.
    tf.import_graph_def(graph_def, name="")
    g = sess.graph  # explicit, instead of relying on tf.get_default_graph()
    inp = g.get_tensor_by_name(input_tensor_name)
    out = g.get_tensor_by_name(output_tensor_name)

    sigs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.saved_model.signature_def_utils.predict_signature_def(
                {"in": inp}, {"out": out}),
    }
    # Tag with SERVING ("serve"): the tag TensorFlow Serving looks up,
    # per the error message shown above.
    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()
保存下来的 SavedModel 文件夹下会有以下两项:
saved_model.pb 文件和 variables 目录
其中 variables 目录为空是正常的:冻结图中的变量已全部转换为常量,保存在 saved_model.pb 中。
将该 SavedModel 目录交给 TensorFlow Serving 加载,即可成功读取模型!
以上这篇keras模型保存为tensorflow的二进制模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。