tensorflow2.0保存和恢复模型3种方法
方法1:只保存模型的权重和偏置
这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。
tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。
save_weights( filepath, overwrite=True, save_format=None )
Arguments:
filepath:String,pathtothefiletosavetheweightsto.WhensavinginTensorFlowformat,thisistheprefixusedforcheckpointfiles(multiplefilesaregenerated).Notethatthe'.h5'suffixcausesweightstobesavedinHDF5format.
overwrite:Whethertosilentlyoverwriteanyexistingfileatthetargetlocation,orprovidetheuserwithamanualprompt.
save_format:Either'tf'or'h5'.Afilepathendingin'.h5'or'.keras'willdefaulttoHDF5ifsave_formatisNone.OtherwiseNonedefaultsto'tf'.
load_weights( filepath, by_name=False )
实例1:
importtensorflowastf fromtensorflowimportkeras fromtensorflow.kerasimportdatasets,layers,optimizers #step1加载训练集和测试集合 mnist=tf.keras.datasets.mnist (x_train,y_train),(x_test,y_test)=mnist.load_data() x_train,x_test=x_train/255.0,x_test/255.0 #step2创建模型 defcreate_model(): returntf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28,28)), tf.keras.layers.Dense(512,activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10,activation='softmax') ]) model=create_model() #step3编译模型主要是确定优化方法,损失函数等 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) #step4模型训练训练一个epochs model.fit(x=x_train, y=y_train, epochs=1, ) #step5模型测试 loss,acc=model.evaluate(x_test,y_test) print("trainmodel,accuracy:{:5.2f}%".format(100*acc)) #step6保存模型的权重和偏置 model.save_weights('./save_weights/my_save_weights') #step7删除模型 delmodel #step8重新创建模型 model=create_model() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) #step9恢复权重 model.load_weights('./save_weights/my_save_weights') #step10测试模型 loss,acc=model.evaluate(x_test,y_test) print("Restoredmodel,accuracy:{:5.2f}%".format(100*acc))
trainmodel,accuracy:96.55%
Restoredmodel,accuracy:96.55%
可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。
方法2:直接保存整个模型
这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。
tf.keras.model类中的save方法和load_model方法
save( filepath, overwrite=True, include_optimizer=True, save_format=None )
Arguments:
filepath:String,pathtoSavedModelorH5filetosavethemodel.
overwrite:Whethertosilentlyoverwriteanyexistingfileatthetargetlocation,orprovidetheuserwithamanualprompt.
include_optimizer:IfTrue,saveoptimizer'sstatetogether.
save_format:Either'tf'or'h5',indicatingwhethertosavethemodeltoTensorflowSavedModelorHDF5.Thedefaultiscurrently'h5',butwillswitchto'tf'inTensorFlow2.0.The'tf'optioniscurrentlydisabled(usetf.keras.experimental.export_saved_modelinstead).
实例2:
importtensorflowastf fromtensorflowimportkeras fromtensorflow.kerasimportdatasets,layers,optimizers #step1加载训练集和测试集合 mnist=tf.keras.datasets.mnist (x_train,y_train),(x_test,y_test)=mnist.load_data() x_train,x_test=x_train/255.0,x_test/255.0 #step2创建模型 defcreate_model(): returntf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28,28)), tf.keras.layers.Dense(512,activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10,activation='softmax') ]) model=create_model() #step3编译模型主要是确定优化方法,损失函数等 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) #step4模型训练训练一个epochs model.fit(x=x_train, y=y_train, epochs=1, ) #step5模型测试 loss,acc=model.evaluate(x_test,y_test) print("trainmodel,accuracy:{:5.2f}%".format(100*acc)) #step6保存模型的权重和偏置 model.save('my_model.h5')#createsaHDF5file'my_model.h5' #step7删除模型 delmodel#deletestheexistingmodel #step8恢复模型 #returnsacompiledmodel #identicaltothepreviousone restored_model=tf.keras.models.load_model('my_model.h5') #step9测试模型 loss,acc=restored_model.evaluate(x_test,y_test) print("Restoredmodel,accuracy:{:5.2f}%".format(100*acc))
trainmodel,accuracy:96.94%
Restoredmodel,accuracy:96.94%
方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型
该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用
以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。