keras模型保存为tensorflow的二进制模型方式
最近需要将使用keras训练的模型移植到手机上使用,因此需要转换到tensorflow的二进制模型。
折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:
#coding=utf-8
importsys
fromkeras.modelsimportload_model
importtensorflowastf
importos
importos.pathasosp
fromkerasimportbackendasK
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freeze the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph is
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
        or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for
        better portability.
    @return The frozen GraphDef.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(
            set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])
        )
        output_names = output_names or []
        # NOTE(review): appending every variable name to output_names keeps all
        # variables reachable during pruning. This mirrors the widely copied
        # snippet; strictly only the model outputs are needed here.
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            # Strip device placements so the graph loads on any machine.
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names
        )
        return frozen_graph
# --- Conversion driver: load a Keras .h5 model and export a frozen .pb ---
input_fld = sys.path[0]  # directory containing the running script
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
    os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

# Disable training-phase behaviour (dropout, batch-norm updates) before loading.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)

# These names are needed by the verification script to look up the tensors.
print('input is:', net_model.input.name)
print('output is:', net_model.output.name)

sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at:',
      osp.join(output_fld, output_graph_name))
上面代码实现保存到当前目录的tensor_model目录下。
验证:
importtensorflowastf
importnumpyasnp
importPIL.ImageasImage
importcv2
def recognize(jpg_path, pb_file_path):
    """Load a frozen .pb graph and run inference on a single grayscale image.

    @param jpg_path Path to the test image (read as grayscale; assumed to be
        a 28x28 MNIST-style digit — TODO confirm against the trained model).
    @param pb_file_path Path to the frozen TensorFlow graph (.pb, binary).
    """
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tensors = tf.import_graph_def(output_graph_def, name="")
            print(tensors)
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            # Dump every operation so the tensor names below can be checked.
            for m in sess.graph.get_operations():
                print(m.values())
            # Tensor names come from the input/output names printed by the
            # export script.
            input_x = sess.graph.get_tensor_by_name("convolution2d_1_input:0")
            print(input_x)
            out_softmax = sess.graph.get_tensor_by_name("activation_4/Softmax:0")
            print(out_softmax)
            img = cv2.imread(jpg_path, 0)
            # Invert and scale pixel values into [0, 1] before feeding.
            img_out_softmax = sess.run(
                out_softmax,
                feed_dict={input_x: 1.0 - np.array(img).reshape((-1, 28, 28, 1)) / 255.0})
            print("img_out_softmax:", img_out_softmax)
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("label:", prediction_labels)


# NOTE(review): the export script above writes 'tensor_model.pb'; verify this
# path matches the file actually produced.
pb_path = 'tensorflow_model/constant_graph_weights.pb'
img = 'test/6/8_48.jpg'
recognize(img, pb_path)
补充知识:如何将 keras 训练好的模型转换成 tensorflow 的 .pb 文件,并在 TensorFlow Serving 环境中调用
首先 keras 训练好的模型通过自带的 model.save() 保存下来是 .model(.h5)格式的文件
模型载入是通过 my_model = keras.models.load_model(filepath)
要将该模型转换为 .pb 格式的 TensorFlow 模型,代码如下:
#-*-coding:utf-8-*-
fromkeras.layers.coreimportActivation,Dense,Flatten
fromkeras.layers.embeddingsimportEmbedding
fromkeras.layers.recurrentimportLSTM
fromkeras.layersimportDropout
fromkeras.layers.wrappersimportBidirectional
fromkeras.modelsimportSequential,load_model
fromkeras.preprocessingimportsequence
fromsklearn.model_selectionimporttrain_test_split
importcollections
fromcollectionsimportdefaultdict
importjieba
importnumpyasnp
importsys
# Python 2-only hack: force the interpreter's default string encoding to
# UTF-8 so mixed str/unicode operations on Chinese text do not raise.
# NOTE(review): reload(sys)/sys.setdefaultencoding do not exist in Python 3;
# remove these two lines when running under Python 3.
reload(sys)
sys.setdefaultencoding('utf-8')
importtensorflowastf
importos
importos.pathasosp
fromkerasimportbackendasK
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze a session's variables into constants and return the pruned GraphDef.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names Variable names to leave unfrozen, or None for all.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Strip device directives for portability.
    @return The frozen GraphDef.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(
            set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])
        )
        output_names = output_names or []
        # NOTE(review): adding all variables keeps them reachable during
        # pruning — mirrors the widely copied snippet.
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names
        )
        return frozen_graph
# --- Export the trained biLSTM Keras model to a frozen .pb graph ---
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
    os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

# Inference mode: disable dropout / batch-norm training behaviour.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is:', net_model.input.name)
print('output is:', net_model.output.name)

sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io
# as_text=False writes a binary GraphDef; the SavedModel conversion below
# reads the file with ParseFromString, which only accepts the binary format
# (the original as_text=True would make that load fail).
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at:',
      osp.join(output_fld, output_graph_name))
然后模型就存成了.pb格式的文件
问题就来了,这样存下来的 .pb 格式的文件是 frozen model
如果通过 TensorFlow Serving 启用模型的话,会报错:
E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`
因为 TensorFlow Serving 希望读取的是 saved model
于是需要将 frozen model 转化为 saved model 格式,解决方案如下:
# --- Convert the frozen graph into a SavedModel for TensorFlow Serving ---
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
# NOTE(review): the export above writes 'tensor_model_v3.pb'; confirm this
# path points at the frozen graph actually produced.
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing
    tf.import_graph_def(graph_def, name="")
    g = tf.get_default_graph()
    # NOTE(review): net_model is defined in the earlier export script; when
    # running this snippet standalone, replace these with the literal
    # input/output tensor names printed during export.
    inp = g.get_tensor_by_name(net_model.input.name)
    out = g.get_tensor_by_name(net_model.output.name)
    # Register the default serving signature mapping "in" -> input tensor
    # and "out" -> output tensor.
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": inp}, {"out": out})
    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()
于是保存下来的 saved model 文件夹下就有两个文件:
saved_model.pb 和 variables
其中 variables 可以为空
于是将该 saved model 导入 TensorFlow Serving 再读取,成功!
以上这篇keras模型保存为tensorflow的二进制模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。