PyTorch focal_loss: multi-class and binary classification examples
Without further ado, here is the code.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Supports both multi-class and binary classification
class FocalLoss(nn.Module):
    """
    An implementation of Focal Loss, with support for label-smoothed cross entropy,
    as proposed in 'Focal Loss for Dense Object Detection'
    (https://arxiv.org/abs/1708.02002).
    Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
    :param num_class: number of classes
    :param alpha: (tensor) the per-class scalar weighting factor for this criterion
    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified
        examples (p > 0.5), putting more focus on hard, misclassified examples
    :param smooth: (float, double) label-smoothing value for the cross entropy
    :param balance_index: (int) index of the class to balance; must be specified when alpha is a float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            assert len(self.alpha) == self.num_class
            self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
            self.alpha = self.alpha / self.alpha.sum()
        elif isinstance(self.alpha, float):
            alpha = torch.ones(self.num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[balance_index] = self.alpha
            self.alpha = alpha
        else:
            raise TypeError('Unsupported alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0, 1]')
    def forward(self, input, target):
        logit = F.softmax(input, dim=1)

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m = d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = target.view(-1, 1)

        # N = input.size(0)
        # alpha = torch.ones(N, self.num_class)
        # alpha = alpha * (1 - self.alpha)
        # alpha = alpha.scatter_(1, target.long(), self.alpha)
        epsilon = 1e-10
        alpha = self.alpha
        if alpha.device != input.device:
            alpha = alpha.to(input.device)

        idx = target.cpu().long()
        one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth, 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + epsilon
        logpt = pt.log()

        gamma = self.gamma

        # select each sample's alpha and flatten to shape (N,) so it
        # broadcasts correctly with pt; the widely copied `alpha[idx]`
        # yields shape (N, 1, 1), which broadcasts to (N, 1, N) and is a bug
        alpha = alpha[idx.view(-1)].squeeze(-1)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss
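As a quick sanity check on the formula `Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)`, here is a minimal functional sketch, independent of the class above (the name `focal_loss_simple` is hypothetical): with gamma = 0 and uniform alpha, focal loss reduces to plain cross entropy, and gamma > 0 can only shrink each per-sample term.

```python
import torch
import torch.nn.functional as F

def focal_loss_simple(logits, target, gamma=2.0, eps=1e-10):
    # pt: softmax probability assigned to the true class of each sample
    pt = F.softmax(logits, dim=1).gather(1, target.view(-1, 1)).squeeze(1) + eps
    # FL = -(1 - pt)^gamma * log(pt); gamma = 0 gives plain cross entropy
    return (-(1 - pt) ** gamma * pt.log()).mean()

torch.manual_seed(0)
logits = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))

fl0 = focal_loss_simple(logits, target, gamma=0.0)
ce = F.cross_entropy(logits, target)
print(torch.allclose(fl0, ce, atol=1e-5))  # True: gamma=0 recovers cross entropy

fl2 = focal_loss_simple(logits, target, gamma=2.0)
print(bool(fl2 < fl0))  # True: (1 - pt)^2 < 1 down-weights every term
```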
class BCEFocalLoss(torch.nn.Module):
    """
    Focal loss for binary classification, with a fixed alpha.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        # clamping keeps log() finite when sigmoid saturates to exactly 0 or 1
        pt = torch.sigmoid(_input).clamp(1e-7, 1 - 1e-7)
        alpha = self.alpha
        loss = -alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
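The binary variant admits the same kind of sanity check. A self-contained sketch (the name `bce_focal_simple` is hypothetical): with gamma = 0 and alpha = 0.5, every term becomes exactly half of the standard binary cross entropy.

```python
import torch
import torch.nn.functional as F

def bce_focal_simple(logits, target, gamma=2.0, alpha=0.25, eps=1e-7):
    # pt is the predicted probability of the positive class
    pt = torch.sigmoid(logits).clamp(eps, 1 - eps)
    return (-alpha * (1 - pt) ** gamma * target * pt.log()
            - (1 - alpha) * pt ** gamma * (1 - target) * (1 - pt).log()).mean()

torch.manual_seed(0)
logits = torch.randn(16)
target = torch.randint(0, 2, (16,)).float()

# with gamma=0 and alpha=0.5, each term is 0.5 * standard BCE
fl = bce_focal_simple(logits, target, gamma=0.0, alpha=0.5)
half_bce = 0.5 * F.binary_cross_entropy_with_logits(logits, target)
print(torch.allclose(fl, half_bce, atol=1e-5))  # True
```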
That concludes this PyTorch focal_loss example for multi-class and binary classification; I hope it serves as a useful reference.