A worked example of MNIST classification in PyTorch
The torchvision package contains popular datasets, model architectures, and common image transformation tools.
torchvision.datasets includes the following datasets:
MNIST
COCO (for image captioning and object detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10
torchvision.models
The submodules of torchvision.models provide the following model architectures:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
You can construct a model with random weights by calling its constructor.
torchvision.transforms
Transforms on PIL.Image
from __future__ import print_function
import argparse  # Python command-line argument parser
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)


if __name__ == '__main__':
    main()
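One detail worth tracing is why `fc1` takes exactly 320 inputs. Each 5×5 convolution (no padding) shrinks a 28×28 MNIST image, and each max-pool halves it; a short sketch of the arithmetic (standard valid-convolution output formula, no PyTorch needed):

```python
# Trace the spatial size of a 1x28x28 input through Net's forward pass.
def conv_out(size, kernel, stride=1):
    # Output size of a valid (unpadded) convolution.
    return (size - kernel) // stride + 1

s = 28
s = conv_out(s, 5) // 2   # conv1 (5x5): 28 -> 24, max_pool2d(2): 24 -> 12
s = conv_out(s, 5) // 2   # conv2 (5x5): 12 -> 8,  max_pool2d(2): 8 -> 4
features = 20 * s * s     # conv2 has 20 output channels: 20 * 4 * 4
print(features)           # 320
```

This is the flattened size that `x.view(-1, 320)` produces and that `nn.Linear(320, 50)` consumes.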
This concludes the worked example of MNIST classification in PyTorch; I hope it serves as a useful reference.