Python利用逻辑回归模型解决MNIST手写数字识别问题详解
本文实例讲述了Python利用逻辑回归模型解决MNIST手写数字识别问题。分享给大家供大家参考,具体如下:
1、MNIST手写识别问题
%matplotlibinline importtensorflowastf importtensorflow.examples.tutorials.mnist.input_dataasinput_data mnist=input_data.read_data_sets('MNIST_data/',one_hot=True) importmatplotlib.pyplotasplt defplot_image(image):#图片显示函数 plt.imshow(image.reshape(28,28),cmap='binary') plt.show() print("训练集数量:",mnist.train.num_examples, "特征值组成:",mnist.train.images.shape, "标签组成:",mnist.train.labels.shape) batch_images,batch_labels=mnist.train.next_batch(batch_size=10)#批量读取数据 print(batch_images.shape,batch_labels.shape) print('标签值:',np.argmax(mnist.train.labels[1000]),end='')#np.argmax()得到实际值 print('独热编码表示:',mnist.train.labels[1000]) plot_image(mnist.train.images[1000])#显示数据集中第1000张图片
2、逻辑回归
模型的训练
importnumpyasnp importtensorflowastf importtensorflow.examples.tutorials.mnist.input_dataasinput_data mnist=input_data.read_data_sets('MNIST_data/',one_hot=True) importmatplotlib.pyplotasplt #定义占位符、变量、前向计算 x=tf.placeholder(tf.float32,[None,784],name='x') y=tf.placeholder(tf.float32,[None,10],name='y') W=tf.Variable(tf.random_normal([784,10]),name='W') b=tf.Variable(tf.zeros([10]),name='b') forward=tf.matmul(x,W)+b pred=tf.nn.softmax(forward)#通过softmax将线性结果分类处理 #计算预测值与真实值的匹配个数 correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) #将上一步得到的布尔值转换为浮点数,并求平均值,得到准确率 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) ss=tf.Session() init=tf.global_variables_initializer() ss.run(init) #超参数设置 train_epochs=50 batch_size=100#每个批次的样本数 batch_num=int(mnist.train.num_examples/batch_size)#一轮需要训练多少批 learning_rate=0.01 #定义交叉熵损失函数、梯度下降优化器 loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) forepochinrange(train_epochs): forbatchinrange(batch_num):#分批次读取数据进行训练 xs,ys=mnist.train.next_batch(batch_size) ss.run(optimizer,feed_dict={x:xs,y:ys}) #每轮训练结束后通过带入验证集的数据,检测模型的损失与准去率 loss,acc=ss.run([loss_function,accuracy],\ feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) print('第%2d轮训练:损失为:%9f,准确率:%.4f'%(epoch+1,loss,acc))
结果预测
prediction=ss.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) defshow_result(images,labels,prediction,index,num=10):#绘制图形显示预测结果 pic=plt.gcf()#获取当前图像 pic.set_size_inches(10,12)#设置图片大小 foriinrange(0,num): sub_pic=plt.subplot(5,5,i+1)#获取第i个子图 #将第index个images信息显示到子图上 sub_pic.imshow(np.reshape(images[index],(28,28)),cmap='binary') title="label:"+str(np.argmax(labels[index]))#设置子图的title内容 iflen(prediction)>0: title+=",predict:"+str(prediction[index]) sub_pic.set_title(title,fontsize=10) sub_pic.set_xticks([])#设置x、y坐标轴不显示 sub_pic.set_yticks([]) index+=1 plt.show() show_result(mnist.test.images,mnist.test.labels,prediction,10)
更多关于Python相关内容感兴趣的读者可查看本站专题:《Python数据结构与算法教程》、《Python加密解密算法与技巧总结》、《Python编码操作技巧总结》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。