tensorflow tf.train.batch之数据批量读取方式
在进行大量数据训练神经网络的时候,可能需要批量读取数据。于是参考了这篇文章的代码,结果发现数据一直批量循环输出,不会在数据的末尾自动停止。
然后发现这篇博文说slice_input_producer()这个函数有一个形参num_epochs,通过设置它的值就可以控制全部数据循环输出几次。
于是我设置之后出现以下的报错:
tensorflow.python.framework.errors_impl.FailedPreconditionError:Attemptingtouseuninitializedvalueinput_producer/input_producer/limit_epochs/epochs [[Node:input_producer/input_producer/limit_epochs/CountUpTo=CountUpTo[T=DT_INT64,_class=["loc:@input_producer/input_producer/limit_epochs/epochs"],limit=2,_device="/job:localhost/replica:0/task:0/cpu:0"](input_producer/input_producer/limit_epochs/epochs)]]
找了好久,都不知道为什么会错,于是只好去看看slice_input_producer()函数的源码,结果在源码中发现作者说这个num_epochs如果不是空的话,就是一个局部变量,需要先调用global_variables_initializer()函数初始化。
于是我调用了之后,一切就正常了,特此记录下来,希望其他人遇到的时候能够及时找到原因。
哈哈,这是笔者第一次通过阅读源码解决了问题,心情还是有点小激动。啊啊,扯远了,上最终成功的代码:
importpandasaspd
importnumpyasnp
importtensorflowastf
defgenerate_data():
num=25
label=np.asarray(range(0,num))
images=np.random.random([num,5])
print('labelsize:{},imagesize{}'.format(label.shape,images.shape))
returnimages,label
defget_batch_data():
label,images=generate_data()
input_queue=tf.train.slice_input_producer([images,label],shuffle=False,num_epochs=2)
image_batch,label_batch=tf.train.batch(input_queue,batch_size=5,num_threads=1,capacity=64,allow_smaller_final_batch=False)
returnimage_batch,label_batch
images,label=get_batch_data()
sess=tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())#就是这一行
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess,coord)
try:
whilenotcoord.should_stop():
i,l=sess.run([images,label])
print(i)
print(l)
excepttf.errors.OutOfRangeError:
print('Donetraining')
finally:
coord.request_stop()
coord.join(threads)
sess.close()
以上这篇tensorflowtf.train.batch之数据批量读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。