将自己的数据集制作成TFRecord格式教程
在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入
此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格
式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。
1.原本的数据集
此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。
2.制作成TFRecord格式
tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1两类,想知道文件夹名与label关系的,可以自己保存起来。
#生成整数型的属性
def_int64_feature(value):
returntf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#生成字符串类型的属性
def_bytes_feature(value):
returntf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#制作TFRecord格式
defcreateTFRecord(filename,mapfile):
class_map={}
data_dir='/home/wc/DataSet/traffic/testTFRecord/'
classes={'xiansu60','xiansu100'}
#输出TFRecord文件的地址
writer=tf.python_io.TFRecordWriter(filename)
forindex,nameinenumerate(classes):
class_path=data_dir+name+'/'
class_map[index]=name
forimg_nameinos.listdir(class_path):
img_path=class_path+img_name#每个图片的地址
img=Image.open(img_path)
img=img.resize((224,224))
img_raw=img.tobytes()#将图片转化成二进制格式
example=tf.train.Example(features=tf.train.Features(feature={
'label':_int64_feature(index),
'image_raw':_bytes_feature(img_raw)
}))
writer.write(example.SerializeToString())
writer.close()
txtfile=open(mapfile,'w+')
forkeyinclass_map.keys():
txtfile.writelines(str(key)+":"+class_map[key]+"\n")
txtfile.close()
此段代码,运行完后会产生生成的.tfrecord文件。
3.读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程
#读取train.tfrecord中的数据
defread_and_decode(filename):
#创建一个reader来读取TFRecord文件中的样例
reader=tf.TFRecordReader()
#创建一个队列来维护输入文件列表
filename_queue=tf.train.string_input_producer([filename],shuffle=False,num_epochs=1)
#从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
_,serialized_example=reader.read(filename_queue)
#print_,serialized_example
#解析读入的一个样例,如果需要解析多个,可以用parse_example
features=tf.parse_single_example(
serialized_example,
features={'label':tf.FixedLenFeature([],tf.int64),
'image_raw':tf.FixedLenFeature([],tf.string),})
#将字符串解析成图像对应的像素数组
img=tf.decode_raw(features['image_raw'],tf.uint8)
img=tf.reshape(img,[224,224,3])#reshape为128*128*3通道图片
img=tf.image.per_image_standardization(img)
labels=tf.cast(features['label'],tf.int32)
returnimg,labels
4.将图片几个一打包,形成batch
defcreateBatch(filename,batchsize): images,labels=read_and_decode(filename) min_after_dequeue=10 capacity=min_after_dequeue+3*batchsize image_batch,label_batch=tf.train.shuffle_batch([images,labels], batch_size=batchsize, capacity=capacity, min_after_dequeue=min_after_dequeue ) label_batch=tf.one_hot(label_batch,depth=2) returnimage_batch,label_batch
5.主函数
if__name__=="__main__":
#训练图片两张为一个batch,进行训练,测试图片一起进行测试
mapfile="/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
train_filename="/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#createTFRecord(train_filename,mapfile)
test_filename="/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#createTFRecord(test_filename,mapfile)
image_batch,label_batch=createBatch(filename=train_filename,batchsize=2)
test_images,test_labels=createBatch(filename=test_filename,batchsize=20)
withtf.Session()assess:
initop=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(initop)
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
try:
step=0
while1:
_image_batch,_label_batch=sess.run([image_batch,label_batch])
step+=1
printstep
print(_label_batch)
excepttf.errors.OutOfRangeError:
print("trainDatadone!")
try:
step=0
while1:
_test_images,_test_labels=sess.run([test_images,test_labels])
step+=1
printstep
#print_image_batch.shape
print(_test_labels)
excepttf.errors.OutOfRangeError:
print("TESTdone!")
coord.request_stop()
coord.join(threads)
此时,生成的batch,就可以feed进网络了。
以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。