tensorflow使用range_input_producer多线程读取数据实例
先放关键代码:
i=tf.train.range_input_producer(NUM_EXPOCHES,num_epochs=1,shuffle=False).dequeue() inputs=tf.slice(array,[i*BATCH_SIZE],[BATCH_SIZE])
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
main.py内容:
importtensorflowastf importcodecs BATCH_SIZE=6 NUM_EXPOCHES=5 definput_producer(): array=codecs.open("test.txt").readlines() array=map(lambdaline:line.strip(),array) i=tf.train.range_input_producer(NUM_EXPOCHES,num_epochs=1,shuffle=False).dequeue() inputs=tf.slice(array,[i*BATCH_SIZE],[BATCH_SIZE]) returninputs classInputs(object): def__init__(self): self.inputs=input_producer() defmain(*args,**kwargs): inputs=Inputs() init=tf.group(tf.initialize_all_variables(), tf.initialize_local_variables()) sess=tf.Session() coord=tf.train.Coordinator() threads=tf.train.start_queue_runners(sess=sess,coord=coord) sess.run(init) try: index=0 whilenotcoord.should_stop()andindex<10: datalines=sess.run(inputs.inputs) index+=1 print("step:%d,batchdata:%s"%(index,str(datalines))) excepttf.errors.OutOfRangeError: print("Donetraing:-------Epochlimitreached") exceptKeyboardInterrupt: print("keyboardinterrputdetected,stoptraining") finally: coord.request_stop() coord.join(threads) sess.close() delsess if__name__=="__main__": main()
输出:
step:1,batchdata:['1''2''3''4''5''6'] step:2,batchdata:['7''8''9''10''11''12'] step:3,batchdata:['13''14''15''16''17''18'] step:4,batchdata:['19''20''21''22''23''24'] step:5,batchdata:['25''26''27''28''29''30'] Donetraing:-------Epochlimitreached
如果range_input_producer去掉参数num_epochs=1,则输出:
step:1,batchdata:['1''2''3''4''5''6'] step:2,batchdata:['7''8''9''10''11''12'] step:3,batchdata:['13''14''15''16''17''18'] step:4,batchdata:['19''20''21''22''23''24'] step:5,batchdata:['25''26''27''28''29''30'] step:6,batchdata:['1''2''3''4''5''6'] step:7,batchdata:['7''8''9''10''11''12'] step:8,batchdata:['13''14''15''16''17''18'] step:9,batchdata:['19''20''21''22''23''24'] step:10,batchdata:['25''26''27''28''29''30']
有一点需要注意,文件总共有35条数据,BATCH_SIZE=6表示每个batch包含6条数据,NUM_EXPOCHES=5表示产生5个batch,如果NUM_EXPOCHES=6,则总共需要36条数据,就会报如下错误:
InvalidArgumentError(seeabovefortraceback):Expectedsize[0]in[0,5],butgot6 [[Node:Slice=Slice[Index=DT_INT32,T=DT_STRING,_device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input,Slice/begin/_5,Slice/size)]]
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES的取值能只能在0到5之间。
以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。