使用tensorflow DataSet实现高效加载变长文本输入
DataSet是tensorflow1.3版本推出的一个high-level的api,在1.3版本还只是处于测试阶段,1.4版本已经正式推出。
在网上搜了一遍,发现关于使用DataSet加载文本的资料比较少,官方举的例子只是csv格式的,要求csv文件中所有样本必须具有相同的维度,也就是padding必须在写入csv文件之前做掉,这会增加文件的大小。
经过一番折腾试验,这里给出一个DataSet+TFRecords加载变长样本的范例。
首先先把变长的数据写入到TFRecords文件:
defwritedata(): xlist=[[1,2,3],[4,5,6,8]] ylist=[1,2] #这里的数据只是举个例子来说明样本的文本长度不一样,第一个样本3个词标签1,第二个样本4个词标签2 writer=tf.python_io.TFRecordWriter("train.tfrecords") foriinrange(2): x=xlist[i] y=ylist[i] example=tf.train.Example(features=tf.train.Features(feature={ "y":tf.train.Feature(int64_list=tf.train.Int64List(value=[y])), 'x':tf.train.Feature(int64_list=tf.train.Int64List(value=x)) })) writer.write(example.SerializeToString()) writer.close()
然后用DataSet加载:
feature_names=['x'] defmy_input_fn(file_path,perform_shuffle=False,repeat_count=1): defparse(example_proto): features={"x":tf.VarLenFeature(tf.int64), "y":tf.FixedLenFeature([1],tf.int64)} parsed_features=tf.parse_single_example(example_proto,features) x=tf.sparse_tensor_to_dense(parsed_features["x"]) x=tf.cast(x,tf.int32) x=dict(zip(feature_names,[x])) y=tf.cast(parsed_features["y"],tf.int32) returnx,y dataset=(tf.contrib.data.TFRecordDataset(file_path) .map(parse)) ifperform_shuffle: dataset=dataset.shuffle(buffer_size=256) dataset=dataset.repeat(repeat_count) dataset=dataset.padded_batch(2,padded_shapes=({'x':[6]},[1]))#batchsize为2,并且x按maxlen=6来做padding iterator=dataset.make_one_shot_iterator() batch_features,batch_labels=iterator.get_next() returnbatch_features,batch_labels next_batch=my_input_fn('train.tfrecords',True) init=tf.initialize_all_variables() withtf.Session()assess: sess.run(init) foriinrange(1): xs,y=sess.run(next_batch) print(xs['x']) print(y)
注意变长的数据TFRecords解析要用VarLenFeature,然后用sparse_tensor_to_dense转换。
以上这篇使用tensorflowDataSet实现高效加载变长文本输入就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。