Notes on dataset.shuffle, dataset.batch and dataset.repeat in TensorFlow
batch is easy to understand: it sets the batch size. Note that the last batch of an epoch may contain fewer than batch_size samples.
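A minimal pure-Python model of this behaviour (an illustration, not TensorFlow code): the hypothetical helper `batch` groups consecutive elements, and its `drop_remainder` flag mirrors the parameter of the same name that newer versions of `tf.data.Dataset.batch` accept for discarding the short final batch.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def batch(items: Iterable[T], batch_size: int, drop_remainder: bool = False) -> List[List[T]]:
    """Group consecutive items into lists of batch_size (a model of dataset.batch)."""
    batches: List[List[T]] = []
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    # Leftover items form a final, smaller batch unless drop_remainder is set.
    if current and not drop_remainder:
        batches.append(current)
    return batches


print([len(b) for b in batch(range(11), 4)])                       # [4, 4, 3]
print([len(b) for b in batch(range(11), 4, drop_remainder=True)])  # [4, 4]
```

With 11 samples and a batch size of 4, one epoch yields two full batches and one batch of 3, which is exactly what the examples below print.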
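dataset.repeat corresponds to what is commonly called the number of epochs, but in tf.data the order in which it is combined with dataset.shuffle can cause samples from different epochs to be mixed together.
dataset.shuffle maintains a shuffle buffer holding buffer_size elements. Every sample the graph requests is drawn at random from the shuffle buffer, and each time a sample is taken out, the next sample from the source dataset is added to the buffer to take its place.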
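The buffer mechanism just described can be modelled in plain Python as below. `buffered_shuffle` is a hypothetical helper written for illustration; it follows the fill / draw at random / refill behaviour described above rather than reproducing TensorFlow's actual source code.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int,
                     rng: Optional[random.Random] = None) -> Iterator[T]:
    """Shuffle a stream using a fixed-size buffer (a model of dataset.shuffle)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = rng if rng is not None else random.Random()
    source = iter(items)
    buffer: List[T] = []
    # 1. Fill the buffer with the first buffer_size source elements.
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    while buffer:
        # 2. Emit a uniformly chosen element from the buffer.
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        # 3. Refill that slot with the next source element; once the source
        #    is exhausted, shrink the buffer until it is empty.
        try:
            buffer[idx] = next(source)
        except StopIteration:
            buffer.pop(idx)


print(list(buffered_shuffle(range(11), 3, random.Random(0))))
```

Every source element is emitted exactly once, so the output is always a permutation of the input; only the buffer size limits how far an element can move.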
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # run on the CPU only

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 11 samples / batch size 4 -> 3 batches per epoch, x 2 epochs = 6 batches.
    # One more sess.run would raise OutOfRangeError, so stop at the end of the data.
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break
# source dataset
[[0.5488135 0.71518937]
 [0.60276338 0.54488318]
 [0.4236548 0.64589411]
 [0.43758721 0.891773]
 [0.96366276 0.38344152]
 [0.79172504 0.52889492]
 [0.56804456 0.92559664]
 [0.07103606 0.0871293]
 [0.0202184 0.83261985]
 [0.77815675 0.87001215]
 [0.97861834 0.79915856]]

# batches obtained after shuffle and batch
[[0.4236548 0.64589411]
 [0.60276338 0.54488318]
 [0.43758721 0.891773]
 [0.5488135 0.71518937]]
[[0.96366276 0.38344152]
 [0.56804456 0.92559664]
 [0.0202184 0.83261985]
 [0.79172504 0.52889492]]
[[0.07103606 0.0871293]
 [0.97861834 0.79915856]
 [0.77815675 0.87001215]]  # the last batch has only 3 samples
[[0.60276338 0.54488318]
 [0.5488135 0.71518937]
 [0.43758721 0.891773]
 [0.79172504 0.52889492]]
[[0.4236548 0.64589411]
 [0.56804456 0.92559664]
 [0.0202184 0.83261985]
 [0.07103606 0.0871293]]
[[0.77815675 0.87001215]
 [0.96366276 0.38344152]
 [0.97861834 0.79915856]]  # the last batch has only 3 samples
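To see why each epoch in this output is a complete, independently shuffled pass, here is a small pure-Python model of `shuffle(3) -> batch(4) -> repeat(2)`. The helper names are hypothetical, and re-shuffling on every pass is an assumption chosen to match the `reshuffle_each_iteration=True` default of `Dataset.shuffle` (which is why the second epoch above comes out in a different order from the first).

```python
import random
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shuffle_stream(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Buffered shuffle: fill the buffer, emit a random element, refill."""
    source = iter(items)
    buffer: List[T] = []
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    while buffer:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        try:
            buffer[idx] = next(source)
        except StopIteration:
            buffer.pop(idx)


def batches(stream: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group a stream into lists of batch_size; the final list may be shorter."""
    current: List[T] = []
    for item in stream:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


def shuffle_batch_repeat(data: Sequence[T], buffer_size: int, batch_size: int,
                         epochs: int, seed: int = 0) -> List[List[T]]:
    """Model of dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)."""
    rng = random.Random(seed)
    out: List[List[T]] = []
    for _ in range(epochs):
        # repeat() re-runs shuffle + batch from scratch, so no batch spans two epochs.
        out.extend(batches(shuffle_stream(data, buffer_size, rng), batch_size))
    return out


result = shuffle_batch_repeat(list(range(11)), buffer_size=3, batch_size=4, epochs=2)
print([len(b) for b in result])  # [4, 4, 3, 4, 4, 3]
```

Because batch is applied inside each epoch, every epoch ends with its own short batch of 3, matching the TensorFlow output.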
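1. Using the buffer_size passed to shuffle (3 here), first fill the shuffle buffer with three samples from the source dataset:
shuffle buffer:
[0.5488135 0.71518937]
[0.60276338 0.54488318]
[0.4236548 0.64589411]
2. Take one sample at random from the buffer and put it into the batch:
shuffle buffer:
[0.5488135 0.71518937]
[0.60276338 0.54488318]
batch:
[0.4236548 0.64589411]
3. The shuffle buffer now holds fewer than three samples, so the next sample is pulled from the source dataset:
shuffle buffer:
[0.5488135 0.71518937]
[0.60276338 0.54488318]
[0.43758721 0.891773]
4. Again take one sample at random from the buffer and put it into the batch:
shuffle buffer:
[0.5488135 0.71518937]
[0.43758721 0.891773]
batch:
[0.4236548 0.64589411]
[0.60276338 0.54488318]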
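5. Repeat this process. It follows that with buffer_size=1 the dataset is not shuffled at all, while with buffer_size equal to (or larger than) the number of samples the entire dataset is shuffled.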
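The conclusion in step 5 follows from a simple bound: by the time the k-th sample (counting from 0) is emitted, only the first k + buffer_size source samples have been read, so no sample can jump more than buffer_size - 1 positions ahead. The sketch below checks this on a pure-Python model of steps 1 to 4; `buffered_shuffle_indices` and `max_lookahead` are hypothetical helpers, not TensorFlow functions.

```python
import random
from typing import List, Sequence


def buffered_shuffle_indices(n: int, buffer_size: int, rng: random.Random) -> List[int]:
    """Return the order in which source indices 0..n-1 leave a shuffle buffer."""
    buffer = list(range(min(buffer_size, n)))  # step 1: fill the buffer
    next_index = len(buffer)
    order: List[int] = []
    while buffer:
        idx = rng.randrange(len(buffer))       # steps 2/4: take a random buffered sample
        order.append(buffer[idx])
        if next_index < n:                     # step 3: refill from the source
            buffer[idx] = next_index
            next_index += 1
        else:
            buffer.pop(idx)
    return order


def max_lookahead(order: Sequence[int]) -> int:
    """Largest number of positions any sample moved ahead of its source position."""
    return max((src - pos for pos, src in enumerate(order)), default=0)


for b in (1, 3, 11):
    order = buffered_shuffle_indices(11, b, random.Random(0))
    print(b, order, max_lookahead(order))
```

With buffer_size=1 the lookahead is 0, which forces the original order; with buffer_size=11 every sample is in the buffer from the start, so any sample can come first.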
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # run on the CPU only

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)  # buffer_size=1: no shuffling
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # Print every batch until the data is exhausted.
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break

# source dataset
[[0.5488135 0.71518937]
 [0.60276338 0.54488318]
 [0.4236548 0.64589411]
 [0.43758721 0.891773]
 [0.96366276 0.38344152]
 [0.79172504 0.52889492]
 [0.56804456 0.92559664]
 [0.07103606 0.0871293]
 [0.0202184 0.83261985]
 [0.77815675 0.87001215]
 [0.97861834 0.79915856]]

# batches: samples keep their original order in both epochs
[[0.5488135 0.71518937]
 [0.60276338 0.54488318]
 [0.4236548 0.64589411]
 [0.43758721 0.891773]]
[[0.96366276 0.38344152]
 [0.79172504 0.52889492]
 [0.56804456 0.92559664]
 [0.07103606 0.0871293]]
[[0.0202184 0.83261985]
 [0.77815675 0.87001215]
 [0.97861834 0.79915856]]
[[0.5488135 0.71518937]
 [0.60276338 0.54488318]
 [0.4236548 0.64589411]
 [0.43758721 0.891773]]
[[0.96366276 0.38344152]
 [0.79172504 0.52889492]
 [0.56804456 0.92559664]
 [0.07103606 0.0871293]]
[[0.0202184 0.83261985]
 [0.77815675 0.87001215]
 [0.97861834 0.79915856]]
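Note what happens if repeat is applied before shuffle:
The official TensorFlow documentation says that applying repeat before shuffle gives better performance, but it blurs the epoch boundaries between samples.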
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # run on the CPU only

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)    # repeat first: 22 samples in one stream
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # Print every batch until the data is exhausted.
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break

# source dataset
[[0.5488135 0.71518937]
 [0.60276338 0.54488318]
 [0.4236548 0.64589411]
 [0.43758721 0.891773]
 [0.96366276 0.38344152]
 [0.79172504 0.52889492]
 [0.56804456 0.92559664]
 [0.07103606 0.0871293]
 [0.0202184 0.83261985]
 [0.77815675 0.87001215]
 [0.97861834 0.79915856]]

# batches
[[0.56804456 0.92559664]
 [0.5488135 0.71518937]
 [0.60276338 0.54488318]
 [0.07103606 0.0871293]]
[[0.96366276 0.38344152]
 [0.43758721 0.891773]
 [0.43758721 0.891773]  # the same sample appears twice in this batch
 [0.77815675 0.87001215]]
[[0.79172504 0.52889492]  # the same sample appears twice in this batch
 [0.79172504 0.52889492]
 [0.60276338 0.54488318]
 [0.4236548 0.64589411]]
[[0.07103606 0.0871293]
 [0.4236548 0.64589411]
 [0.96366276 0.38344152]
 [0.5488135 0.71518937]]
[[0.97861834 0.79915856]
 [0.0202184 0.83261985]
 [0.77815675 0.87001215]
 [0.56804456 0.92559664]]
[[0.0202184 0.83261985]
 [0.97861834 0.79915856]]  # the last batch has only 2 samples (22 = 5 x 4 + 2), while all earlier batches have 4
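A pure-Python model of `repeat(2) -> shuffle(11) -> batch(4)` reproduces both effects seen above. It is a sketch under the assumption that shuffle behaves like the fill / draw / refill buffer described earlier; the helper names are hypothetical. Because the second copy of a sample enters the buffer while first-epoch samples are still waiting there, both copies can end up in the same batch, and since batching now runs over one stream of 22 samples only the very last batch is short.

```python
import random
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shuffle_stream(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Buffered shuffle: fill the buffer, emit a random element, refill."""
    source = iter(items)
    buffer: List[T] = []
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    while buffer:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        try:
            buffer[idx] = next(source)
        except StopIteration:
            buffer.pop(idx)


def batches(stream: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group a stream into lists of batch_size; the final list may be shorter."""
    current: List[T] = []
    for item in stream:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


def repeat_shuffle_batch(data: Sequence[T], epochs: int, buffer_size: int,
                         batch_size: int, seed: int = 0) -> List[List[T]]:
    """Model of dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)."""
    rng = random.Random(seed)
    # repeat() first: the shuffle buffer sees one long stream of epochs * len(data) samples.
    repeated = (x for _ in range(epochs) for x in data)
    return list(batches(shuffle_stream(repeated, buffer_size, rng), batch_size))


result = repeat_shuffle_batch(list(range(11)), epochs=2, buffer_size=11, batch_size=4)
print([len(b) for b in result])  # [4, 4, 4, 4, 4, 2]
print("batches with a repeated sample:", [b for b in result if len(set(b)) < len(b)])
```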
Usage example:
import tensorflow as tf  # TensorFlow 1.x API


def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)

    def decode_libsvm(line):
        # columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
        # features = dict(zip(CSV_COLUMNS, columns))
        # labels = features.pop(LABEL_COLUMN)
        # Each line has the form "label id:val id:val ...", so split on spaces.
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[1:], ':')
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        # feat_ids = tf.reshape(feat_ids, shape=[-1, FLAGS.field_size])
        # for i in range(splits.dense_shape.eval()[0]):
        #     feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
        #     feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
        # return tf.reshape(feat_ids, shape=[-1, field_size]), tf.reshape(feat_vals, shape=[-1, field_size]), labels
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from input files using the Dataset API; accepts one filename or a list of filenames.
    # Multi-threaded pre-processing, then prefetch.
    dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)

    # Randomize input using a window of 256 elements (read into memory).
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # Shuffle before repeat so that samples from different epochs do not blend together.
    dataset = dataset.repeat(num_epochs)
    # batch() requires every line to yield the same number of features.
    dataset = dataset.batch(batch_size)

    # return dataset.make_one_shot_iterator()
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    # return tf.reshape(batch_ids, shape=[-1, field_size]), tf.reshape(batch_vals, shape=[-1, field_size]), batch_labels
    return batch_features, batch_labels
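To make the per-line parsing in decode_libsvm concrete, here is a plain-Python equivalent for a single line. `decode_libsvm_line` is a hypothetical helper for illustration only: it returns flat lists, whereas the TensorFlow version returns `feat_ids` and `feat_vals` as tensors of shape [num_features, 1] because `tf.split` keeps the split axis.

```python
from typing import Dict, List, Tuple, Union


def decode_libsvm_line(line: str) -> Tuple[Dict[str, List[Union[int, float]]], float]:
    """Parse 'label id:val id:val ...' into ({feat_ids, feat_vals}, label)."""
    columns = line.split()          # whitespace-separated fields
    label = float(columns[0])       # first field is the label
    feat_ids: List[Union[int, float]] = []
    feat_vals: List[Union[int, float]] = []
    for field in columns[1:]:
        fid, fval = field.split(":")  # each remaining field is "id:value"
        feat_ids.append(int(fid))
        feat_vals.append(float(fval))
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, label


features, label = decode_libsvm_line("1 3:0.5 10:1.0 27:0.25")
print(features, label)
```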