tensorflow入门:TFRecordDataset变长数据的batch读取详解
在上一篇文章tensorflow入门:tfrecord和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:
1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。
2.使用dataset中的padded_batch方法来进行,参数padded_shapes#指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([],[mx_length],[d1,...,dn]);该方法的函数说明如下:
padded_batch( batch_size, padded_shapes, padding_values=None#默认使用各类型数据的默认值,一般使用时可忽略该项 )
使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:
importtensorflowastf fromtensorflow.contrib.learn.python.learn.datasets.mnistimportread_data_sets mnist=read_data_sets("MNIST_data/",one_hot=True) defget_tfrecords_example(feature,label): tfrecords_features={} feat_shape=feature.shape tfrecords_features['feature']=tf.train.Feature(float_list=tf.train.FloatList(value=feature)) tfrecords_features['shape']=tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape))) tfrecords_features['label']=tf.train.Feature(float_list=tf.train.FloatList(value=label)) returntf.train.Example(features=tf.train.Features(feature=tfrecords_features)) defmake_tfrecord(data,outf_nm='mnist-train'): feats,labels=data outf_nm+='.tfrecord' tfrecord_wrt=tf.python_io.TFRecordWriter(outf_nm) ndatas=len(labels) print(feats[0].dtype,feats[0].shape,ndatas) assertlen(labels[0])>1 forinxinrange(ndatas): ed=random.randint(0,3)#随机丢掉几个数据点,以使长度不等 exmp=get_tfrecords_example(feats[inx][:-ed],labels[inx]) exmp_serial=exmp.SerializeToString() tfrecord_wrt.write(exmp_serial) tfrecord_wrt.close() importrandom nDatas=len(mnist.train.labels) inx_lst=range(nDatas) random.shuffle(inx_lst) random.shuffle(inx_lst) ntrains=int(0.85*nDatas) #maketrainingset data=([mnist.train.images[i]foriininx_lst[:ntrains]],\ [mnist.train.labels[i]foriininx_lst[:ntrains]]) make_tfrecord(data,outf_nm='mnist-train') #makevalidationset data=([mnist.train.images[i]foriininx_lst[ntrains:]],\ [mnist.train.labels[i]foriininx_lst[ntrains:]]) make_tfrecord(data,outf_nm='mnist-val') #maketestset data=(mnist.test.images,mnist.test.labels) make_tfrecord(data,outf_nm='mnist-test')
用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([],tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:
importtensorflowastf train_f,val_f,test_f=['mnist-%s.tfrecord'%iforiin['train','val','test']] defparse_exmp(serial_exmp): feats=tf.parse_single_example(serial_exmp,features={'feature':tf.VarLenFeature(tf.float32),\ 'label':tf.FixedLenFeature([10],tf.float32),'shape':tf.FixedLenFeature([],tf.int64)}) image=tf.sparse_tensor_to_dense(feats['feature'])#使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换 label=tf.reshape(feats['label'],[2,5])#把label变成[2,5],以说明array数据如何padding shape=tf.cast(feats['shape'],tf.int32) returnimage,label,shape defget_dataset(fname): dataset=tf.data.TFRecordDataset(fname) returndataset.map(parse_exmp)#usepadded_batchmethodifpaddingneeded epochs=16 batch_size=50 padded_shapes=([784],[3,5],[])#把imagepad至784,把labelpad至[3,5],shape是一个scalar,不输入数字 #trainingdataset dataset_train=get_dataset(train_f) dataset_train=dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size,padded_shapes=padded_shapes)
以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。