tensorflow使用range_input_producer多线程读取数据实例
先放关键代码:
i=tf.train.range_input_producer(NUM_EXPOCHES,num_epochs=1,shuffle=False).dequeue() inputs=tf.slice(array,[i*BATCH_SIZE],[BATCH_SIZE])
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
main.py内容:
importtensorflowastf
importcodecs
BATCH_SIZE=6
NUM_EXPOCHES=5
definput_producer():
array=codecs.open("test.txt").readlines()
array=map(lambdaline:line.strip(),array)
i=tf.train.range_input_producer(NUM_EXPOCHES,num_epochs=1,shuffle=False).dequeue()
inputs=tf.slice(array,[i*BATCH_SIZE],[BATCH_SIZE])
returninputs
classInputs(object):
def__init__(self):
self.inputs=input_producer()
defmain(*args,**kwargs):
inputs=Inputs()
init=tf.group(tf.initialize_all_variables(),
tf.initialize_local_variables())
sess=tf.Session()
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
sess.run(init)
try:
index=0
whilenotcoord.should_stop()andindex<10:
datalines=sess.run(inputs.inputs)
index+=1
print("step:%d,batchdata:%s"%(index,str(datalines)))
excepttf.errors.OutOfRangeError:
print("Donetraing:-------Epochlimitreached")
exceptKeyboardInterrupt:
print("keyboardinterrputdetected,stoptraining")
finally:
coord.request_stop()
coord.join(threads)
sess.close()
delsess
if__name__=="__main__":
main()
输出:
step:1,batchdata:['1''2''3''4''5''6'] step:2,batchdata:['7''8''9''10''11''12'] step:3,batchdata:['13''14''15''16''17''18'] step:4,batchdata:['19''20''21''22''23''24'] step:5,batchdata:['25''26''27''28''29''30'] Donetraing:-------Epochlimitreached
如果range_input_producer去掉参数num_epochs=1,则输出:
step:1,batchdata:['1''2''3''4''5''6'] step:2,batchdata:['7''8''9''10''11''12'] step:3,batchdata:['13''14''15''16''17''18'] step:4,batchdata:['19''20''21''22''23''24'] step:5,batchdata:['25''26''27''28''29''30'] step:6,batchdata:['1''2''3''4''5''6'] step:7,batchdata:['7''8''9''10''11''12'] step:8,batchdata:['13''14''15''16''17''18'] step:9,batchdata:['19''20''21''22''23''24'] step:10,batchdata:['25''26''27''28''29''30']
有一点需要注意,文件总共有35条数据,BATCH_SIZE=6表示每个batch包含6条数据,NUM_EXPOCHES=5表示产生5个batch,如果NUM_EXPOCHES=6,则总共需要36条数据,就会报如下错误:
InvalidArgumentError(seeabovefortraceback):Expectedsize[0]in[0,5],butgot6 [[Node:Slice=Slice[Index=DT_INT32,T=DT_STRING,_device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input,Slice/begin/_5,Slice/size)]]
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES的取值能只能在0到5之间。
以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。