Implementing MNIST Handwritten Digit Recognition with PyTorch
Experiment environment
Win10 + Anaconda + Jupyter Notebook
PyTorch 1.1.0
Python 3.7
GPU (optional)
The MNIST dataset
MNIST contains 60,000 training samples and 10,000 test samples, each a 28x28 image; it is often called the "Hello World" of computer vision. The CNN used in this article pushes the recognition accuracy on MNIST to 99%. Let's get started.
Importing packages

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

torch.__version__
Defining hyperparameters

BATCH_SIZE = 512
EPOCHS = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
The dataset

We use the MNIST dataset built into torchvision and read the training and test data with a DataLoader. If you have already downloaded the dataset, you can set download=False here.

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)
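The Normalize transform above subtracts the dataset mean and divides by the standard deviation; 0.1307 and 0.3081 are the commonly quoted mean and std of the MNIST training pixels (after ToTensor has scaled them to [0, 1]). A minimal sketch of what it does to a single pixel value:

```python
# Sketch of transforms.Normalize((0.1307,), (0.3081,)) on one pixel value.
# ToTensor has already scaled raw pixels from [0, 255] to [0.0, 1.0].
MNIST_MEAN = 0.1307  # mean of the MNIST training pixels
MNIST_STD = 0.3081   # std of the MNIST training pixels

def normalize(x, mean=MNIST_MEAN, std=MNIST_STD):
    return (x - mean) / std

print(normalize(0.0))  # background (black) pixels map to about -0.42
print(normalize(1.0))  # fully-on (white) pixels map to about 2.82
```

After this transform the inputs are roughly zero-mean and unit-variance, which generally helps the optimizer converge.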
Defining the network

The network has two convolutional layers and two fully connected layers; the final output has 10 dimensions, one for each of the digits 0-9.

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # input: (1, 28, 28), output: (10, 24, 24)
        self.conv2 = nn.Conv2d(10, 20, 3)  # input: (10, 12, 12), output: (20, 10, 10)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)            # (batch, 10, 24, 24)
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)  # (batch, 10, 12, 12)
        out = self.conv2(out)          # (batch, 20, 10, 10)
        out = F.relu(out)
        out = out.view(in_size, -1)    # flatten to (batch, 2000)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out
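The shapes in the comments above follow the standard output-size formula for a convolution or pooling layer. A quick pure-Python check of that arithmetic (the formula only, not PyTorch itself):

```python
def out_size(size, kernel, stride=1, padding=0):
    # Spatial output size of Conv2d / MaxPool2d:
    # floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

s = 28
s = out_size(s, 5)            # conv1 (5x5 kernel): 28 -> 24
s = out_size(s, 2, stride=2)  # max_pool2d(2, 2):   24 -> 12
s = out_size(s, 3)            # conv2 (3x3 kernel): 12 -> 10
print(s)  # 10, hence fc1's input size of 20 * 10 * 10 = 2000
```

Running this arithmetic before writing the Linear layer is a cheap way to avoid the classic shape-mismatch error on the first forward pass.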
Instantiating the network

model = ConvNet().to(DEVICE)                # move the network to the GPU if one is available
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer
Defining the training function

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
Defining the test function

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up the batch losses
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
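The model ends with F.log_softmax and the loss is F.nll_loss; composed, these compute the cross-entropy loss (PyTorch's F.cross_entropy fuses the two into one call on raw logits). A small pure-Python sketch of that composition, with hypothetical scores for three classes:

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax: z_i - logsumexp(z).
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def nll_loss(log_probs, target):
    # Negative log-likelihood of the true class.
    return -log_probs[target]

logits = [1.0, 3.0, 0.5]  # hypothetical raw scores for 3 classes
loss = nll_loss(log_softmax(logits), target=1)
print(loss)  # small, since class 1 already has the largest score
```

This is also why reduction='sum' plus the final division by len(test_loader.dataset) gives the true per-sample average loss, whereas averaging per-batch losses would be skewed by a smaller last batch.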
Starting training

for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)
Experiment results

Train Epoch: 1 [14848/60000 (25%)]  Loss: 0.375058
Train Epoch: 1 [30208/60000 (50%)]  Loss: 0.255248
Train Epoch: 1 [45568/60000 (75%)]  Loss: 0.128060
Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%)
Train Epoch: 2 [14848/60000 (25%)]  Loss: 0.093066
Train Epoch: 2 [30208/60000 (50%)]  Loss: 0.087888
Train Epoch: 2 [45568/60000 (75%)]  Loss: 0.068078
Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%)
Train Epoch: 3 [14848/60000 (25%)]  Loss: 0.043926
Train Epoch: 3 [30208/60000 (50%)]  Loss: 0.037321
Train Epoch: 3 [45568/60000 (75%)]  Loss: 0.068404
Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%)
Train Epoch: 4 [14848/60000 (25%)]  Loss: 0.031654
Train Epoch: 4 [30208/60000 (50%)]  Loss: 0.041341
Train Epoch: 4 [45568/60000 (75%)]  Loss: 0.036493
Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%)
Train Epoch: 5 [14848/60000 (25%)]  Loss: 0.027688
Train Epoch: 5 [30208/60000 (50%)]  Loss: 0.019488
Train Epoch: 5 [45568/60000 (75%)]  Loss: 0.018023
Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%)
Train Epoch: 6 [14848/60000 (25%)]  Loss: 0.024212
Train Epoch: 6 [30208/60000 (50%)]  Loss: 0.018689
Train Epoch: 6 [45568/60000 (75%)]  Loss: 0.040412
Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%)
Train Epoch: 7 [14848/60000 (25%)]  Loss: 0.030426
Train Epoch: 7 [30208/60000 (50%)]  Loss: 0.026939
Train Epoch: 7 [45568/60000 (75%)]  Loss: 0.010722
Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%)
Train Epoch: 8 [14848/60000 (25%)]  Loss: 0.021109
Train Epoch: 8 [30208/60000 (50%)]  Loss: 0.034845
Train Epoch: 8 [45568/60000 (75%)]  Loss: 0.011223
Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%)
Train Epoch: 9 [14848/60000 (25%)]  Loss: 0.011391
Train Epoch: 9 [30208/60000 (50%)]  Loss: 0.008091
Train Epoch: 9 [45568/60000 (75%)]  Loss: 0.039870
Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%)
Train Epoch: 10 [14848/60000 (25%)]  Loss: 0.026813
Train Epoch: 10 [30208/60000 (50%)]  Loss: 0.011159
Train Epoch: 10 [45568/60000 (75%)]  Loss: 0.024884
Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%)
Train Epoch: 11 [14848/60000 (25%)]  Loss: 0.006420
Train Epoch: 11 [30208/60000 (50%)]  Loss: 0.003641
Train Epoch: 11 [45568/60000 (75%)]  Loss: 0.003402
Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%)
Train Epoch: 12 [14848/60000 (25%)]  Loss: 0.006866
Train Epoch: 12 [30208/60000 (50%)]  Loss: 0.012617
Train Epoch: 12 [45568/60000 (75%)]  Loss: 0.008548
Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%)
Train Epoch: 13 [14848/60000 (25%)]  Loss: 0.010539
Train Epoch: 13 [30208/60000 (50%)]  Loss: 0.002952
Train Epoch: 13 [45568/60000 (75%)]  Loss: 0.002313
Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%)
Train Epoch: 14 [14848/60000 (25%)]  Loss: 0.002100
Train Epoch: 14 [30208/60000 (50%)]  Loss: 0.000779
Train Epoch: 14 [45568/60000 (75%)]  Loss: 0.005952
Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%)
Train Epoch: 15 [14848/60000 (25%)]  Loss: 0.006053
Train Epoch: 15 [30208/60000 (50%)]  Loss: 0.002559
Train Epoch: 15 [45568/60000 (75%)]  Loss: 0.002555
Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%)
Train Epoch: 16 [14848/60000 (25%)]  Loss: 0.000895
Train Epoch: 16 [30208/60000 (50%)]  Loss: 0.004923
Train Epoch: 16 [45568/60000 (75%)]  Loss: 0.002339
Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%)
Train Epoch: 17 [14848/60000 (25%)]  Loss: 0.004136
Train Epoch: 17 [30208/60000 (50%)]  Loss: 0.000927
Train Epoch: 17 [45568/60000 (75%)]  Loss: 0.002084
Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%)
Train Epoch: 18 [14848/60000 (25%)]  Loss: 0.004508
Train Epoch: 18 [30208/60000 (50%)]  Loss: 0.001272
Train Epoch: 18 [45568/60000 (75%)]  Loss: 0.000543
Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%)
Train Epoch: 19 [14848/60000 (25%)]  Loss: 0.001699
Train Epoch: 19 [30208/60000 (50%)]  Loss: 0.000661
Train Epoch: 19 [45568/60000 (75%)]  Loss: 0.000275
Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%)
Train Epoch: 20 [14848/60000 (25%)]  Loss: 0.000441
Train Epoch: 20 [30208/60000 (50%)]  Loss: 0.000695
Train Epoch: 20 [45568/60000 (75%)]  Loss: 0.000467
Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)
Summary

The workflow of a real project: find a dataset, preprocess the data, define a model, tune the hyperparameters, train and evaluate, and then use the results to adjust the hyperparameters or the model.