Python利用逻辑回归模型解决MNIST手写数字识别问题详解
本文实例讲述了Python利用逻辑回归模型解决MNIST手写数字识别问题。分享给大家供大家参考,具体如下:
1、MNIST手写识别问题
%matplotlibinline
importtensorflowastf
importtensorflow.examples.tutorials.mnist.input_dataasinput_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
importmatplotlib.pyplotasplt
defplot_image(image):#图片显示函数
plt.imshow(image.reshape(28,28),cmap='binary')
plt.show()
print("训练集数量:",mnist.train.num_examples,
"特征值组成:",mnist.train.images.shape,
"标签组成:",mnist.train.labels.shape)
batch_images,batch_labels=mnist.train.next_batch(batch_size=10)#批量读取数据
print(batch_images.shape,batch_labels.shape)
print('标签值:',np.argmax(mnist.train.labels[1000]),end='')#np.argmax()得到实际值
print('独热编码表示:',mnist.train.labels[1000])
plot_image(mnist.train.images[1000])#显示数据集中第1000张图片
2、逻辑回归
模型的训练
importnumpyasnp
importtensorflowastf
importtensorflow.examples.tutorials.mnist.input_dataasinput_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
importmatplotlib.pyplotasplt
#定义占位符、变量、前向计算
x=tf.placeholder(tf.float32,[None,784],name='x')
y=tf.placeholder(tf.float32,[None,10],name='y')
W=tf.Variable(tf.random_normal([784,10]),name='W')
b=tf.Variable(tf.zeros([10]),name='b')
forward=tf.matmul(x,W)+b
pred=tf.nn.softmax(forward)#通过softmax将线性结果分类处理
#计算预测值与真实值的匹配个数
correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
#将上一步得到的布尔值转换为浮点数,并求平均值,得到准确率
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
ss=tf.Session()
init=tf.global_variables_initializer()
ss.run(init)
#超参数设置
train_epochs=50
batch_size=100#每个批次的样本数
batch_num=int(mnist.train.num_examples/batch_size)#一轮需要训练多少批
learning_rate=0.01
#定义交叉熵损失函数、梯度下降优化器
loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
forepochinrange(train_epochs):
forbatchinrange(batch_num):#分批次读取数据进行训练
xs,ys=mnist.train.next_batch(batch_size)
ss.run(optimizer,feed_dict={x:xs,y:ys})
#每轮训练结束后通过带入验证集的数据,检测模型的损失与准去率
loss,acc=ss.run([loss_function,accuracy],\
feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print('第%2d轮训练:损失为:%9f,准确率:%.4f'%(epoch+1,loss,acc))
结果预测
prediction=ss.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
defshow_result(images,labels,prediction,index,num=10):#绘制图形显示预测结果
pic=plt.gcf()#获取当前图像
pic.set_size_inches(10,12)#设置图片大小
foriinrange(0,num):
sub_pic=plt.subplot(5,5,i+1)#获取第i个子图
#将第index个images信息显示到子图上
sub_pic.imshow(np.reshape(images[index],(28,28)),cmap='binary')
title="label:"+str(np.argmax(labels[index]))#设置子图的title内容
iflen(prediction)>0:
title+=",predict:"+str(prediction[index])
sub_pic.set_title(title,fontsize=10)
sub_pic.set_xticks([])#设置x、y坐标轴不显示
sub_pic.set_yticks([])
index+=1
plt.show()
show_result(mnist.test.images,mnist.test.labels,prediction,10)
更多关于Python相关内容感兴趣的读者可查看本站专题:《Python数据结构与算法教程》、《Python加密解密算法与技巧总结》、《Python编码操作技巧总结》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。