Pytorch实现的手写数字mnist识别功能完整示例
本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:
importtorch
importtorchvisionastv
importtorchvision.transformsastransforms
importtorch.nnasnn
importtorch.optimasoptim
importargparse
#定义是否使用GPU
device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")
#定义网络结构
classLeNet(nn.Module):
def__init__(self):
super(LeNet,self).__init__()
self.conv1=nn.Sequential(#input_size=(1*28*28)
nn.Conv2d(1,6,5,1,2),#padding=2保证输入输出尺寸相同
nn.ReLU(),#input_size=(6*28*28)
nn.MaxPool2d(kernel_size=2,stride=2),#output_size=(6*14*14)
)
self.conv2=nn.Sequential(
nn.Conv2d(6,16,5),
nn.ReLU(),#input_size=(16*10*10)
nn.MaxPool2d(2,2)#output_size=(16*5*5)
)
self.fc1=nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU()
)
self.fc2=nn.Sequential(
nn.Linear(120,84),
nn.ReLU()
)
self.fc3=nn.Linear(84,10)
#定义前向传播过程,输入为x
defforward(self,x):
x=self.conv1(x)
x=self.conv2(x)
#nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
x=x.view(x.size()[0],-1)
x=self.fc1(x)
x=self.fc2(x)
x=self.fc3(x)
returnx
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser=argparse.ArgumentParser()
parser.add_argument('--outf',default='./model/',help='foldertooutputimagesandmodelcheckpoints')#模型保存路径
parser.add_argument('--net',default='./model/net.pth',help="pathtonetG(tocontinuetraining)")#模型加载路径
opt=parser.parse_args()
#超参数设置
EPOCH=8#遍历数据集次数
BATCH_SIZE=64#批处理尺寸(batch_size)
LR=0.001#学习率
#定义数据预处理方式
transform=transforms.ToTensor()
#定义训练数据集
trainset=tv.datasets.MNIST(
root='./data/',
train=True,
download=True,
transform=transform)
#定义训练批处理数据
trainloader=torch.utils.data.DataLoader(
trainset,
batch_size=BATCH_SIZE,
shuffle=True,
)
#定义测试数据集
testset=tv.datasets.MNIST(
root='./data/',
train=False,
download=True,
transform=transform)
#定义测试批处理数据
testloader=torch.utils.data.DataLoader(
testset,
batch_size=BATCH_SIZE,
shuffle=False,
)
#定义损失函数lossfunction和优化方式(采用SGD)
net=LeNet().to(device)
criterion=nn.CrossEntropyLoss()#交叉熵损失函数,通常用于多分类问题上
optimizer=optim.SGD(net.parameters(),lr=LR,momentum=0.9)
#训练
if__name__=="__main__":
forepochinrange(EPOCH):
sum_loss=0.0
#数据读取
fori,datainenumerate(trainloader):
inputs,labels=data
inputs,labels=inputs.to(device),labels.to(device)
#梯度清零
optimizer.zero_grad()
#forward+backward
outputs=net(inputs)
loss=criterion(outputs,labels)
loss.backward()
optimizer.step()
#每训练100个batch打印一次平均loss
sum_loss+=loss.item()
ifi%100==99:
print('[%d,%d]loss:%.03f'
%(epoch+1,i+1,sum_loss/100))
sum_loss=0.0
#每跑完一次epoch测试一下准确率
withtorch.no_grad():
correct=0
total=0
fordataintestloader:
images,labels=data
images,labels=images.to(device),labels.to(device)
outputs=net(images)
#取得分最高的那个类
_,predicted=torch.max(outputs.data,1)
total+=labels.size(0)
correct+=(predicted==labels).sum()
print('第%d个epoch的识别准确率为:%d%%'%(epoch+1,(100*correct/total)))
#torch.save(net.state_dict(),'%s/net_%03d.pth'%(opt.outf,epoch+1))
更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。