A PyTorch implementation of the ResNeXt network
This requires installing the pretrainedmodels package first (pip install pretrainedmodels).
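As a quick sanity check (this snippet is not part of the original script), you can confirm that the package imports correctly and exposes the ResNeXt variant used below; model_names and pretrained_settings are the lookup tables the Cadene pretrainedmodels library ships with:

import pretrainedmodels

# List the architectures the library provides and confirm the one this script fine-tunes is present.
print('resnext101_64x4d' in pretrainedmodels.model_names)
# Shows the ImageNet preprocessing settings (input size, mean, std) expected by that model.
print(pretrainedmodels.pretrained_settings['resnext101_64x4d'])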
"""
FinetuningTorchvisionModels
"""
from__future__importprint_function
from__future__importdivision
importtorch
importtorch.nnasnn
importtorch.optimasoptim
importnumpyasnp
importtorchvision
fromtorchvisionimportdatasets,models,transforms
importmatplotlib.pyplotasplt
importtime
importos
importcopy
importargparse
importpretrainedmodels.models.resnextasresnext
print("PyTorchVersion:",torch.__version__)
print("TorchvisionVersion:",torchvision.__version__)
# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
# data_dir = "./data/hymenoptera_data"
data_dir = "/media/dell/dell/data/13/"
# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception, resnext]
model_name = "resnext"
# Number of classes in the dataset
num_classes = 171
# Batch size for training (change depending on how much memory you have)
batch_size = 16
# Number of epochs to train for
num_epochs = 1000
# Flag for feature extracting. When False, we finetune the whole model,
# when True we only update the reshaped layer params
feature_extract = False
# Command-line arguments, so the paths can be overridden Linux-style on the command line
parser = argparse.ArgumentParser(description='PyTorch resnext')
parser.add_argument('--outf', default='/home/dell/Desktop/zhou/train7', help='folder to output images and model checkpoints')  # output folder for results
parser.add_argument('--net', default='/home/dell/Desktop/zhou/train7/resnext.pth', help="path to net (to continue training)")  # model path when resuming training
args = parser.parse_args()
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    # def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, scheduler, is_inception=False):
    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    print("Start Training, resnext!")  # the loop below makes num_epochs passes over the dataset
    with open("/home/dell/Desktop/zhou/train7/acc.txt", "w") as f1:
        with open("/home/dell/Desktop/zhou/train7/log.txt", "w") as f2:
            for epoch in range(num_epochs):
                print('Epoch {}/{}'.format(epoch + 1, num_epochs))
                print('*' * 10)
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        # scheduler.step()
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode
                    running_loss = 0.0
                    running_corrects = 0
                    # Iterate over data.
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        # zero the parameter gradients
                        optimizer.zero_grad()
                        # forward
                        # track history if only in train
                        with torch.set_grad_enabled(phase == 'train'):
                            # Get model outputs and calculate loss
                            # Special case for inception because in training it has an auxiliary output. In train
                            # mode we calculate the loss by summing the final output and the auxiliary output
                            # but in testing we only consider the final output.
                            if is_inception and phase == 'train':
                                # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                                outputs, aux_outputs = model(inputs)
                                loss1 = criterion(outputs, labels)
                                loss2 = criterion(aux_outputs, labels)
                                loss = loss1 + 0.4 * loss2
                            else:
                                outputs = model(inputs)
                                loss = criterion(outputs, labels)
                            _, preds = torch.max(outputs, 1)
                            # backward + optimize only if in training phase
                            if phase == 'train':
                                loss.backward()
                                optimizer.step()
                        # statistics
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += torch.sum(preds == labels.data)
                    epoch_loss = running_loss / len(dataloaders[phase].dataset)
                    epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
                    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
                    f2.write('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
                    f2.write('\n')
                    f2.flush()
                    # deep copy the model
                    if phase == 'val':
                        if (epoch + 1) % 5 == 0:
                            # print('Saving model......')
                            torch.save(model.state_dict(), '%s/inception_%03d.pth' % (args.outf, epoch + 1))
                        f1.write("EPOCH=%03d, Accuracy=%.3f%%" % (epoch + 1, 100 * epoch_acc))
                        f1.write('\n')
                        f1.flush()
                    if phase == 'val' and epoch_acc > best_acc:
                        with open("/home/dell/Desktop/zhou/train7/best_acc.txt", "w") as f3:
                            f3.write("EPOCH=%d, best_acc=%.3f%%" % (epoch + 1, 100 * epoch_acc))
                        best_acc = epoch_acc
                        best_model_wts = copy.deepcopy(model.state_dict())
                    if phase == 'val':
                        val_acc_history.append(epoch_acc)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    # variables is model specific.
    model_ft = None
    input_size = 0
    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "resnext":
        """ resnext
        Be careful, expects (3, 224, 224) sized images
        """
        model_ft = resnext.resnext101_64x4d(num_classes=1000, pretrained='imagenet')
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.last_linear = nn.Linear(2048, num_classes)
        # pre = '/home/dell/Desktop/zhou/train6/inception_009.pth'
        # model_ft.load_state_dict(torch.load(pre))
        input_size = 224
    else:
        print("Invalid model name, exiting...")
        exit()
    return model_ft, input_size
# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
# Print the model we just instantiated
# print(model_ft)
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
print("Initializing Datasets and Dataloaders...")
# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}
# Detect if we have a GPU available
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# we = '/home/dell/Desktop/dj/inception_050.pth'
# model_ft.load_state_dict(torch.load(we))  # load saved weights to resume training
# Send the model to GPU
model_ft = model_ft.to(device)
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.01, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.95)
# Set up the loss fxn
criterion = nn.CrossEntropyLoss()
print(model_ft)
# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=False)
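train_model saves a checkpoint to args.outf every five epochs (named inception_XXX.pth). The following is a minimal sketch, not part of the original post, of how one might reload such a checkpoint in a later session for evaluation; 'inception_005.pth' is only an assumed example filename, substitute whichever epoch you actually saved:

# Hypothetical example: reload a checkpoint produced by train_model and switch to eval mode.
checkpoint_path = os.path.join(args.outf, 'inception_005.pth')  # assumed filename
model_ft.load_state_dict(torch.load(checkpoint_path, map_location=device))
model_ft.eval()  # evaluation mode before running inference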
That concludes this PyTorch implementation of the ResNeXt network. I hope it serves as a useful reference, and I also hope you will continue to support 毛票票.