Pytorch 实现计算分类器准确率(总分类及子分类)
分类器平均准确率计算:
correct=torch.zeros(1).squeeze().cuda() total=torch.zeros(1).squeeze().cuda() fori,(images,labels)inenumerate(train_loader): images=Variable(images.cuda()) labels=Variable(labels.cuda()) output=model(images) prediction=torch.argmax(output,1) correct+=(prediction==labels).sum().float() total+=len(labels) acc_str='Accuracy:%f'%((correct/total).cpu().detach().data.numpy())
分类器各个子类准确率计算:
correct=list(0.foriinrange(args.class_num)) total=list(0.foriinrange(args.class_num)) fori,(images,labels)inenumerate(train_loader): images=Variable(images.cuda()) labels=Variable(labels.cuda()) output=model(images) prediction=torch.argmax(output,1) res=prediction==labels forlabel_idxinrange(len(labels)): label_single=label[label_idx] correct[label_single]+=res[label_idx].item() total[label_single]+=1 acc_str='Accuracy:%f'%(sum(correct)/sum(total)) foracc_idxinrange(len(train_class_correct)): try: acc=correct[acc_idx]/total[acc_idx] except: acc=0 finally: acc_str+='\tclassID:%d\tacc:%f\t'%(acc_idx+1,acc)
以上这篇Pytorch实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。