Keras之自定义损失(loss)函数用法说明
在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:
def my_loss(y_true, y_pred):
    # y_true: True labels. TensorFlow/Theano tensor
    # y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
    ...
    return scalar  # 返回一个标量值
然后在model.compile中指定即可,如:
model.compile(loss=my_loss,optimizer='sgd')
具体参考Keras官方metrics的定义keras/metrics.py:
"""Built-inmetrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
def binary_accuracy(y_true, y_pred):
    """Mean agreement between `y_true` and the rounded predictions.

    `y_pred` is rounded with `K.round`, compared element-wise against
    `y_true`, and the result is averaged over the last axis.
    """
    return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
def categorical_accuracy(y_true, y_pred):
    """Per-sample match between the argmax of `y_true` and of `y_pred`.

    Both tensors are reduced with `K.argmax` over the last axis; the
    equality result is cast to the backend float type (`K.floatx()`).
    """
    return K.cast(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)),
                  K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
    """Per-sample match between `y_true` labels and the argmax of `y_pred`.

    The predicted label is the argmax of `y_pred` over the last axis, cast
    to the backend float type before being compared with `y_true`. The
    equality result is returned cast to `K.floatx()`.
    """
    # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
    if K.ndim(y_true) == K.ndim(y_pred):
        y_true = K.squeeze(y_true, -1)
    # convert dense predictions to labels
    y_pred_labels = K.argmax(y_pred, axis=-1)
    y_pred_labels = K.cast(y_pred_labels, K.floatx())
    return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
def top_k_categorical_accuracy(y_true, y_pred, k=5):
    """Mean of `K.in_top_k` for the argmax of `y_true` against `y_pred`.

    The target index is the argmax of `y_true` over the last axis; `k`
    (default 5) is forwarded to `K.in_top_k`, and the result is averaged
    over the last axis.
    """
    return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
    """Mean of `K.in_top_k` for integer `y_true` labels against `y_pred`.

    `y_true` is flattened and cast to `'int32'` before being used as the
    target indices; `k` (default 5) is forwarded to `K.in_top_k`, and the
    result is averaged over the last axis.
    """
    # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
    return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
                  axis=-1)
# Aliases: short names bound to the loss functions imported from `.losses`,
# so they can be found by name in this module's globals.
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
def serialize(metric):
    """Return the result of `serialize_keras_object` for `metric`."""
    return serialize_keras_object(metric)
def deserialize(config, custom_objects=None):
    """Resolve `config` to a metric via `deserialize_keras_object`.

    Lookup uses this module's globals (so the built-in metrics and the
    aliases defined here are visible) plus the optional `custom_objects`
    passed by the caller.
    """
    return deserialize_keras_object(config,
                                    module_objects=globals(),
                                    custom_objects=custom_objects,
                                    printable_module_name='metric function')
def get(identifier):
    """Resolve a metric identifier to a metric function.

    # Arguments
        identifier: a config dict, a string name, or a callable.

    # Returns
        The result of `deserialize` for a dict or string identifier,
        or `identifier` itself when it is callable.

    # Raises
        ValueError: if `identifier` is none of the supported kinds.
    """
    if isinstance(identifier, dict):
        # A dict is already a config; pass it through unchanged instead of
        # stringifying the whole dict into a bogus 'class_name'.
        return deserialize(identifier)
    if isinstance(identifier, six.string_types):
        return deserialize(str(identifier))
    if callable(identifier):
        return identifier
    raise ValueError('Could not interpret '
                     'metric function identifier: {}'.format(identifier))
以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。