pytorch 利用lstm做mnist手写数字识别分类的实例
代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。
#-*-coding:utf-8-*- """ CreatedonTueOct908:53:252018 @author:www """ importsys sys.path.append('..') importtorch importdatetime fromtorch.autogradimportVariable fromtorchimportnn fromtorch.utils.dataimportDataLoader fromtorchvisionimporttransformsastfs fromtorchvision.datasetsimportMNIST #定义数据 data_tf=tfs.Compose([ tfs.ToTensor(), tfs.Normalize([0.5],[0.5]) ]) train_set=MNIST('E:/data',train=True,transform=data_tf,download=True) test_set=MNIST('E:/data',train=False,transform=data_tf,download=True) train_data=DataLoader(train_set,64,True,num_workers=4) test_data=DataLoader(test_set,128,False,num_workers=4) #定义模型 classrnn_classify(nn.Module): def__init__(self,in_feature=28,hidden_feature=100,num_class=10,num_layers=2): super(rnn_classify,self).__init__() self.rnn=nn.LSTM(in_feature,hidden_feature,num_layers)#使用两层lstm self.classifier=nn.Linear(hidden_feature,num_class)#将最后一个的rnn使用全连接的到最后的输出结果 defforward(self,x): #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28) x=x.squeeze()#去掉(batch,1,28,28)中的1,变成(batch,28,28) x=x.permute(2,0,1)#将最后一维放到第一维,变成(batch,28,28) out,_=self.rnn(x)#使用默认的隐藏状态,得到的out是(28,batch,hidden_feature) out=out[-1,:,:]#取序列中的最后一个,大小是(batch,hidden_feature) out=self.classifier(out)#得到分类结果 returnout net=rnn_classify() criterion=nn.CrossEntropyLoss() optimizer=torch.optim.Adadelta(net.parameters(),1e-1) #定义训练过程 defget_acc(output,label): total=output.shape[0] _,pred_label=output.max(1) num_correct=(pred_label==label).sum().item() returnnum_correct/total deftrain(net,train_data,valid_data,num_epochs,optimizer,criterion): iftorch.cuda.is_available(): net=net.cuda() prev_time=datetime.datetime.now() forepochinrange(num_epochs): train_loss=0 train_acc=0 net=net.train() forim,labelintrain_data: iftorch.cuda.is_available(): im=Variable(im.cuda())#(bs,3,h,w) label=Variable(label.cuda())#(bs,h,w) else: im=Variable(im) label=Variable(label) #forward output=net(im) loss=criterion(output,label) #backward optimizer.zero_grad() loss.backward() optimizer.step() train_loss+=loss.item() train_acc+=get_acc(output,label) cur_time=datetime.datetime.now() h,remainder=divmod((cur_time-prev_time).seconds,3600) m,s=divmod(remainder,60) time_str="Time%02d:%02d:%02d"%(h,m,s) ifvalid_dataisnotNone: valid_loss=0 valid_acc=0 net=net.eval() forim,labelinvalid_data: iftorch.cuda.is_available(): im=Variable(im.cuda()) label=Variable(label.cuda()) else: im=Variable(im) label=Variable(label) output=net(im) loss=criterion(output,label) valid_loss+=loss.item() valid_acc+=get_acc(output,label) epoch_str=( "Epoch%d.TrainLoss:%f,TrainAcc:%f,ValidLoss:%f,ValidAcc:%f," %(epoch,train_loss/len(train_data), train_acc/len(train_data),valid_loss/len(valid_data), valid_acc/len(valid_data))) else: epoch_str=("Epoch%d.TrainLoss:%f,TrainAcc:%f,"% (epoch,train_loss/len(train_data), train_acc/len(train_data))) prev_time=cur_time print(epoch_str+time_str) train(net,train_data,test_data,10,optimizer,criterion)
以上这篇pytorch利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。