Tensorflow使用tfrecord输入数据格式
Tensorflow提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式。
1.TFRecord格式介绍
TFRecord文件中的数据是通过tf.train.ExampleProtocolBuffer的格式存储的,下面是tf.train.Example的定义
messageExample{
Featuresfeatures=1;
};
messageFeatures{
mapfeatrue=1;
};
messageFeature{
oneofkind{
BytesListbytes_list=1;
FloatListfloat_list=2;
Int64Listint64_list=3;
}
};
从上述代码可以看到,ft.train.Example的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList),实数列表(FloatList)或整数列表(Int64List)。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。
2.将自己的数据转化为TFRecord格式
准备数据
在上一篇中,我们为了像伟大的MNIST致敬,所以选择图像的前缀来进行不同类别的分类依据,但是大多数的情况下,在进行分类任务的过程中,不同的类别都会放在不同的文件夹下,而且类别的个数往往浮动性又很大,所以针对这样的情况,我们现在利用不同类别在不同文件夹中的图像来生成TFRecord.
我们在Iris&Contact这个文件夹下有两个文件夹,分别为iris,contact。对于每个文件夹中存放的是对应的图片
转换数据
数据准备好以后,就开始准备生成TFRecord,具体代码如下:
importos
importtensorflowastf
fromPILimportImage
importmatplotlib.pyplotasplt
cwd='/home/ruyiwei/Documents/Iris&Contact/'
classes={'iris','contact'}
writer=tf.python_io.TFRecordWriter("iris_contact.tfrecords")
forindex,nameinenumerate(classes):
class_path=cwd+name+'/'
forimg_nameinos.listdir(class_path):
img_path=class_path+img_name
img=Image.open(img_path)
img=img.resize((512,80))
img_raw=img.tobytes()
#plt.imshow(img)#ifyouwanttocheckyouimage,pleasedelete'#'
#plt.show()
example=tf.train.Example(features=tf.train.Features(feature={
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
3.Tensorflow从TFRecord中读取数据
defread_and_decode(filename):#readiris_contact.tfrecords
filename_queue=tf.train.string_input_producer([filename])#createaqueue
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)#returnfile_nameandfile
features=tf.parse_single_example(serialized_example,
features={
'label':tf.FixedLenFeature([],tf.int64),
'img_raw':tf.FixedLenFeature([],tf.string),
})#returnimageandlabel
img=tf.decode_raw(features['img_raw'],tf.uint8)
img=tf.reshape(img,[512,80,3])#reshapeimageto512*80*3
img=tf.cast(img,tf.float32)*(1./255)-0.5#throwimgtensor
label=tf.cast(features['label'],tf.int32)#throwlabeltensor
returnimg,label
4.将TFRecord中的数据保存为图片
filename_queue=tf.train.string_input_producer(["iris_contact.tfrecords"])
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)#returnfileandfile_name
features=tf.parse_single_example(serialized_example,
features={
'label':tf.FixedLenFeature([],tf.int64),
'img_raw':tf.FixedLenFeature([],tf.string),
})
image=tf.decode_raw(features['img_raw'],tf.uint8)
image=tf.reshape(image,[512,80,3])
label=tf.cast(features['label'],tf.int32)
withtf.Session()assess:
init_op=tf.initialize_all_variables()
sess.run(init_op)
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(coord=coord)
foriinrange(20):
example,l=sess.run([image,label])#takeoutimageandlabel
img=Image.fromarray(example,'RGB')
img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#saveimage
print(example,l)
coord.request_stop()
coord.join(threads)
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。