关于Tensorflow中的tf.train.batch函数的使用
这两天一直在看tensorflow中的读取数据的队列,说实话,真的是很难懂。也可能我之前没这方面的经验吧,最早我都使用的theano,什么都是自己写。经过这两天的文档以及相关资料,并且请教了国内的师弟。今天算是有点小感受了。简单的说,就是计算图是从一个管道中读取数据的,录入管道是用的现成的方法,读取也是。为了保证多线程的时候从一个管道读取数据不会乱吧,所以这种时候读取的时候需要线程管理的相关操作。今天我实验室了一个简单的操作,就是给一个有序的数据,看看读出来是不是有序的,结果发现是有序的,所以直接给代码:
importtensorflowastf importnumpyasnp defgenerate_data(): num=25 label=np.asarray(range(0,num)) images=np.random.random([num,5,5,3]) print('labelsize:{},imagesize{}'.format(label.shape,images.shape)) returnlabel,images defget_batch_data(): label,images=generate_data() images=tf.cast(images,tf.float32) label=tf.cast(label,tf.int32) input_queue=tf.train.slice_input_producer([images,label],shuffle=False) image_batch,label_batch=tf.train.batch(input_queue,batch_size=10,num_threads=1,capacity=64) returnimage_batch,label_batch image_batch,label_batch=get_batch_data() withtf.Session()assess: coord=tf.train.Coordinator() threads=tf.train.start_queue_runners(sess,coord) i=0 try: whilenotcoord.should_stop(): image_batch_v,label_batch_v=sess.run([image_batch,label_batch]) i+=1 forjinrange(10): print(image_batch_v.shape,label_batch_v[j]) excepttf.errors.OutOfRangeError: print("done") finally: coord.request_stop() coord.join(threads)
记得那个slice_input_producer方法,默认是要shuffle的哈。
Besides,Iwouldliketocommentthiscode.
1:thereisaparameter‘num_epochs'inslice_input_producer,whichcontrolshowmanyepochstheslice_input_producermethodwouldwork.whenthismethodrunsthespecifiedepochs,itwouldreporttheOutOfRangeRrror.Ithinkitwouldbeusefulforourcontrolthetrainingepochs.
2:theoutputofthismethodisonesingleimage,wecouldoperatethissingleimagewithtensorflowAPI,suchasnormalization,crops,andsoon,thenthissingleimageisfeedtobatchmethod,abatchofimagesfortrainingortestingwouldbereceived.
tf.train.batch和tf.train.shuffle_batch的区别用法
tf.train.batch([example,label],batch_size=batch_size,capacity=capacity):[example,label]表示样本和样本标签,这个可以是一个样本和一个样本标签,batch_size是返回的一个batch样本集的样本个数。capacity是队列中的容量。这主要是按顺序组合成一个batch
tf.train.shuffle_batch([example,label],batch_size=batch_size,capacity=capacity,min_after_dequeue)。这里面的参数和上面的一样的意思。不一样的是这个参数min_after_dequeue,一定要保证这参数小于capacity参数的值,否则会出错。这个代表队列中的元素大于它的时候就输出乱的顺序的batch。也就是说这个函数的输出结果是一个乱序的样本排列的batch,不是按照顺序排列的。
上面的函数返回值都是一个batch的样本和样本标签,只是一个是按照顺序,另外一个是随机的
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。