Training an MNIST Handwritten Digit Recognition Model with TensorFlow
This article shares the complete code for training an MNIST handwritten digit recognition model with TensorFlow, for your reference. The code targets the TensorFlow 1.x API; the details are as follows.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE = 784     # Input layer size = pixels per image = 28x28 = 784
OUTPUT_NODE = 10     # Output layer size = number of classes
LAYER1_NODE = 500    # Size of the single hidden layer

BATCH_SIZE = 100     # Number of examples in one training batch; the smaller it is,
                     # the closer training is to stochastic gradient descent;
                     # the larger, the closer to (full-batch) gradient descent

LEARNING_RATE_BASE = 0.8       # Initial learning rate
LEARNING_RATE_DECAY = 0.99     # Learning rate decay rate
REGULARIZATION_RATE = 0.0001   # Coefficient of the regularization term
TRAINING_STEPS = 30000         # Number of training steps
MOVING_AVG_DECAY = 0.99        # Moving average decay rate
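# With this architecture the network has
#   784*500 + 500 (hidden layer) + 500*10 + 10 (output layer) = 397,510
# trainable parameters.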
# Helper function: given the network input and all parameters,
# compute the result of the forward pass.
def inference(input_tensor, avg_class, weights1, biases1,
              weights2, biases2):
    # When no moving-average class is provided, use the current parameter values
    if avg_class is None:
        # Forward pass through the hidden layer
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        # Forward pass through the output layer
        return tf.matmul(layer1, weights2) + biases2
    else:
        # First look up the moving average of each variable,
        # then compute the forward pass with those averages
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) +
            avg_class.average(biases1))
        return tf.matmul(
            layer1, avg_class.average(weights2)) + avg_class.average(biases2)
# The training process
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # Hidden layer parameters
    weights1 = tf.Variable(
        tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))

    # Output layer parameters
    weights2 = tf.Variable(
        tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    # Forward pass without moving averages (avg_class=None)
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Variable holding the number of training steps; marked as not trainable
    global_step = tf.Variable(0, trainable=False)

    # Initialize the moving-average class with the decay rate
    # and the training-step variable
    variable_avgs = tf.train.ExponentialMovingAverage(
        MOVING_AVG_DECAY, global_step)
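    # For each tracked variable, apply() maintains a shadow copy updated as
    #   shadow = decay * shadow + (1 - decay) * variable
    # where decay = min(MOVING_AVG_DECAY, (1 + global_step) / (10 + global_step)),
    # so early in training the average follows the variables more closely.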
    # Apply the moving average to all trainable variables of the network
    variables_avgs_op = variable_avgs.apply(tf.trainable_variables())

    # Forward pass using the moving-averaged parameters
    avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)

    # Cross-entropy as the loss function; the labels y_ are one-hot, so tf.argmax
    # converts them to the class indices that the sparse version expects
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
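    # Equivalently, tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
    # consumes the one-hot labels directly, without the argmax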
    # L2 regularization loss over the weights (the biases are not regularized)
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization
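    # Here regularizer(w) evaluates to REGULARIZATION_RATE * sum(w**2) / 2,
    # i.e. a scaled tf.nn.l2_loss of the weight matrix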
    # Exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,                            # current training step
        mnist.train.num_examples / BATCH_SIZE,  # steps needed for one full epoch
        LEARNING_RATE_DECAY)
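    # This computes
    #   learning_rate = LEARNING_RATE_BASE *
    #                   LEARNING_RATE_DECAY ** (global_step / decay_steps);
    # without staircase=True the exponent is fractional, so the rate decays smoothly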
    # Minimize the loss function
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # One op that both runs backpropagation and updates the moving averages
    with tf.control_dependencies([train_step, variables_avgs_op]):
        train_op = tf.no_op(name='train')
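    # tf.group(train_step, variables_avgs_op) would build an equivalent op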
    # Check the predictions of the forward pass that uses the moving-average model
    correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Create a session and start training
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        # Validation data, used to monitor progress and judge when to stop
        validate_feed = {x: mnist.validation.images,
                         y_: mnist.validation.labels}

        # Test data, used as the final measure of model quality
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        # Training loop
        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print("After %d training step(s), validation accuracy using average "
                      "model is %g" % (i, validate_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # After training, check the model's final accuracy on the test set
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model "
              "is %g" % (TRAINING_STEPS, test_acc))
# Main entry point
def main(argv=None):
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    train(mnist)

# TensorFlow application entry point
if __name__ == '__main__':
    tf.app.run()
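Note that the code above targets TensorFlow 1.x; the tensorflow.examples.tutorials module was removed in TensorFlow 2.x. For readers on TF 2.x, here is a minimal sketch of the same 784-500-10 network written against the tf.keras API, with hyperparameters mirroring the constants above (the moving-average evaluation from the original is omitted, so exact accuracy will differ):

import tensorflow as tf

# Load MNIST, flatten each image to 784 values, and scale pixels to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

# Exponential learning-rate schedule, like tf.train.exponential_decay above
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.8,
    decay_steps=len(x_train) // 100,  # steps per epoch at batch size 100
    decay_rate=0.99)

# 784 -> 500 (ReLU) -> 10, with L2 regularization on the weight matrices
model = tf.keras.Sequential([
    tf.keras.layers.Dense(500, activation='relu', input_shape=(784,),
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001)),
    tf.keras.layers.Dense(10,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001)),
])
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=schedule),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

# Hold out 5,000 images for validation, matching MNIST's validation split;
# epochs=5 is illustrative, while the original trains for roughly 55 epochs
model.fit(x_train, y_train, batch_size=100, epochs=5,
          validation_split=5000 / 60000)
model.evaluate(x_test, y_test)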
The output is as follows:
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462
After 1000 training step(s), validation accuracy using average model is 0.9784
After 2000 training step(s), validation accuracy using average model is 0.9806
After 3000 training step(s), validation accuracy using average model is 0.9798
After 4000 training step(s), validation accuracy using average model is 0.9814
After 5000 training step(s), validation accuracy using average model is 0.9826
After 6000 training step(s), validation accuracy using average model is 0.9828
After 7000 training step(s), validation accuracy using average model is 0.9832
After 8000 training step(s), validation accuracy using average model is 0.9838
After 9000 training step(s), validation accuracy using average model is 0.983
After 10000 training step(s), validation accuracy using average model is 0.9836
After 11000 training step(s), validation accuracy using average model is 0.9822
After 12000 training step(s), validation accuracy using average model is 0.983
After 13000 training step(s), validation accuracy using average model is 0.983
After 14000 training step(s), validation accuracy using average model is 0.9844
After 15000 training step(s), validation accuracy using average model is 0.9832
After 16000 training step(s), validation accuracy using average model is 0.9844
After 17000 training step(s), validation accuracy using average model is 0.9842
After 18000 training step(s), validation accuracy using average model is 0.9842
After 19000 training step(s), validation accuracy using average model is 0.9838
After 20000 training step(s), validation accuracy using average model is 0.9834
After 21000 training step(s), validation accuracy using average model is 0.9828
After 22000 training step(s), validation accuracy using average model is 0.9834
After 23000 training step(s), validation accuracy using average model is 0.9844
After 24000 training step(s), validation accuracy using average model is 0.9838
After 25000 training step(s), validation accuracy using average model is 0.9834
After 26000 training step(s), validation accuracy using average model is 0.984
After 27000 training step(s), validation accuracy using average model is 0.984
After 28000 training step(s), validation accuracy using average model is 0.9836
After 29000 training step(s), validation accuracy using average model is 0.9842
After 30000 training steps, test accuracy using average model is 0.9839
That is all for this article; I hope it is helpful for your study.