tensorflow2.0保存和恢复模型3种方法
方法1:只保存模型的权重和偏置
这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。
tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。
save_weights( filepath, overwrite=True, save_format=None )
Arguments:
filepath:String,pathtothefiletosavetheweightsto.WhensavinginTensorFlowformat,thisistheprefixusedforcheckpointfiles(multiplefilesaregenerated).Notethatthe'.h5'suffixcausesweightstobesavedinHDF5format.
overwrite:Whethertosilentlyoverwriteanyexistingfileatthetargetlocation,orprovidetheuserwithamanualprompt.
save_format:Either'tf'or'h5'.Afilepathendingin'.h5'or'.keras'willdefaulttoHDF5ifsave_formatisNone.OtherwiseNonedefaultsto'tf'.
load_weights( filepath, by_name=False )
实例1:
importtensorflowastf
fromtensorflowimportkeras
fromtensorflow.kerasimportdatasets,layers,optimizers
#step1加载训练集和测试集合
mnist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
x_train,x_test=x_train/255.0,x_test/255.0
#step2创建模型
defcreate_model():
returntf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Dense(512,activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10,activation='softmax')
])
model=create_model()
#step3编译模型主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#step4模型训练训练一个epochs
model.fit(x=x_train,
y=y_train,
epochs=1,
)
#step5模型测试
loss,acc=model.evaluate(x_test,y_test)
print("trainmodel,accuracy:{:5.2f}%".format(100*acc))
#step6保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
#step7删除模型
delmodel
#step8重新创建模型
model=create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#step9恢复权重
model.load_weights('./save_weights/my_save_weights')
#step10测试模型
loss,acc=model.evaluate(x_test,y_test)
print("Restoredmodel,accuracy:{:5.2f}%".format(100*acc))
trainmodel,accuracy:96.55%
Restoredmodel,accuracy:96.55%
可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。
方法2:直接保存整个模型
这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。
tf.keras.model类中的save方法和load_model方法
save( filepath, overwrite=True, include_optimizer=True, save_format=None )
Arguments:
filepath:String,pathtoSavedModelorH5filetosavethemodel.
overwrite:Whethertosilentlyoverwriteanyexistingfileatthetargetlocation,orprovidetheuserwithamanualprompt.
include_optimizer:IfTrue,saveoptimizer'sstatetogether.
save_format:Either'tf'or'h5',indicatingwhethertosavethemodeltoTensorflowSavedModelorHDF5.Thedefaultiscurrently'h5',butwillswitchto'tf'inTensorFlow2.0.The'tf'optioniscurrentlydisabled(usetf.keras.experimental.export_saved_modelinstead).
实例2:
importtensorflowastf
fromtensorflowimportkeras
fromtensorflow.kerasimportdatasets,layers,optimizers
#step1加载训练集和测试集合
mnist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
x_train,x_test=x_train/255.0,x_test/255.0
#step2创建模型
defcreate_model():
returntf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Dense(512,activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10,activation='softmax')
])
model=create_model()
#step3编译模型主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#step4模型训练训练一个epochs
model.fit(x=x_train,
y=y_train,
epochs=1,
)
#step5模型测试
loss,acc=model.evaluate(x_test,y_test)
print("trainmodel,accuracy:{:5.2f}%".format(100*acc))
#step6保存模型的权重和偏置
model.save('my_model.h5')#createsaHDF5file'my_model.h5'
#step7删除模型
delmodel#deletestheexistingmodel
#step8恢复模型
#returnsacompiledmodel
#identicaltothepreviousone
restored_model=tf.keras.models.load_model('my_model.h5')
#step9测试模型
loss,acc=restored_model.evaluate(x_test,y_test)
print("Restoredmodel,accuracy:{:5.2f}%".format(100*acc))
trainmodel,accuracy:96.94%
Restoredmodel,accuracy:96.94%
方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型
该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用
以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。