pytorch 计算Parameter和FLOP的操作
深度学习中,模型训练完后,查看模型的参数量和浮点计算量,在此记录下:
1THOP
在pytorch中有现成的包thop用于计算参数数量和FLOP,首先安装thop:
pipinstallthop
注意安装thop时可能出现如下错误:
解决方法:
pipinstall--upgradegit+https://github.com/Lyken17/pytorch-OpCounter.git#下载源码安装
使用方法如下:
fromtorchvision.modelsimportresnet50#引入ResNet50模型 fromthopimportprofile model=resnet50() flops,params=profile(model,input_size=(1,3,224,224))#profile(模型,输入数据)
对于自己构建的函数也一样,例如shuffleNetV2
fromthopimportprofile fromutils.ShuffleNetV2importshufflenetv2#导入shufflenet2模块 importtorch model_shuffle=shufflenetv2(width_mult=0.5) model=torch.nn.DataParallel(model_shuffle)#调用shufflenet2模型,该模型为自己定义的 flop,para=profile(model,input_size=(1,3,224,224),) print("%.2fM"%(flop/1e6),"%.2fM"%(para/1e6))
更多细节,可参考thopGitHub链接:https://github.com/Lyken17/pytorch-OpCounter
2计算参数
pytorch本身带有计算参数的方法
fromthopimportprofile fromutils.ShuffleNetV2importshufflenetv2#导入shufflenet2模块 importtorch model_shuffle=shufflenetv2(width_mult=0.5) model=torch.nn.DataParallel(model_shuffle) total=sum([param.nelement()forparaminmodel.parameters()]) print("Numberofparameter:%.2fM"%(total/1e6))
补充:pytorch:计算网络模型的计算量(FLOPs)和参数量(Params)
计算量:
FLOPs,FLOP时指浮点运算次数,s是指秒,即每秒浮点运算次数的意思,考量一个网络模型的计算量的标准。
参数量:
Params,是指网络模型中需要训练的参数总数。
第一步:安装模块(thop)
pipinstallthop
第二步:计算
importtorch fromthopimportprofile net=Model()#定义好的网络模型 input=torch.randn(1,3,112,112) flops,params=profile(net,(inputs,)) print('flops:',flops,'params:',params)
注意:
输入input的第一维度是批量(batchsize),批量的大小不回影响参数量,计算量是batch_size=1的倍数
profile(net,(inputs,))的(inputs,)中必须加上逗号,否者会报错
以上为个人经验,希望能给大家一个参考,也希望大家多多支持毛票票。如有错误或未考虑完全的地方,望不吝赐教。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。