关于Tensorflow中的tf.train.batch函数的使用
这两天一直在看tensorflow中的读取数据的队列,说实话,真的是很难懂。也可能我之前没这方面的经验吧,最早我都使用的theano,什么都是自己写。经过这两天的文档以及相关资料,并且请教了国内的师弟。今天算是有点小感受了。简单的说,就是计算图是从一个管道中读取数据的,录入管道是用的现成的方法,读取也是。为了保证多线程的时候从一个管道读取数据不会乱吧,所以这种时候读取的时候需要线程管理的相关操作。今天我实验室了一个简单的操作,就是给一个有序的数据,看看读出来是不是有序的,结果发现是有序的,所以直接给代码:
importtensorflowastf
importnumpyasnp
defgenerate_data():
num=25
label=np.asarray(range(0,num))
images=np.random.random([num,5,5,3])
print('labelsize:{},imagesize{}'.format(label.shape,images.shape))
returnlabel,images
defget_batch_data():
label,images=generate_data()
images=tf.cast(images,tf.float32)
label=tf.cast(label,tf.int32)
input_queue=tf.train.slice_input_producer([images,label],shuffle=False)
image_batch,label_batch=tf.train.batch(input_queue,batch_size=10,num_threads=1,capacity=64)
returnimage_batch,label_batch
image_batch,label_batch=get_batch_data()
withtf.Session()assess:
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess,coord)
i=0
try:
whilenotcoord.should_stop():
image_batch_v,label_batch_v=sess.run([image_batch,label_batch])
i+=1
forjinrange(10):
print(image_batch_v.shape,label_batch_v[j])
excepttf.errors.OutOfRangeError:
print("done")
finally:
coord.request_stop()
coord.join(threads)
记得那个slice_input_producer方法,默认是要shuffle的哈。
Besides,Iwouldliketocommentthiscode.
1:thereisaparameter‘num_epochs'inslice_input_producer,whichcontrolshowmanyepochstheslice_input_producermethodwouldwork.whenthismethodrunsthespecifiedepochs,itwouldreporttheOutOfRangeRrror.Ithinkitwouldbeusefulforourcontrolthetrainingepochs.
2:theoutputofthismethodisonesingleimage,wecouldoperatethissingleimagewithtensorflowAPI,suchasnormalization,crops,andsoon,thenthissingleimageisfeedtobatchmethod,abatchofimagesfortrainingortestingwouldbereceived.
tf.train.batch和tf.train.shuffle_batch的区别用法
tf.train.batch([example,label],batch_size=batch_size,capacity=capacity):[example,label]表示样本和样本标签,这个可以是一个样本和一个样本标签,batch_size是返回的一个batch样本集的样本个数。capacity是队列中的容量。这主要是按顺序组合成一个batch
tf.train.shuffle_batch([example,label],batch_size=batch_size,capacity=capacity,min_after_dequeue)。这里面的参数和上面的一样的意思。不一样的是这个参数min_after_dequeue,一定要保证这参数小于capacity参数的值,否则会出错。这个代表队列中的元素大于它的时候就输出乱的顺序的batch。也就是说这个函数的输出结果是一个乱序的样本排列的batch,不是按照顺序排列的。
上面的函数返回值都是一个batch的样本和样本标签,只是一个是按照顺序,另外一个是随机的
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。