基于Tensorflow批量数据的输入实现方式
基于Tensorflow下的批量数据的输入处理:
1.TensorTFrecords格式
2.h5py的库的数组方法
在tensorflow的框架下写CNN代码,我在书写过程中,感觉不是框架内容难写,更多的是我在对图像的预处理和输入这部分花了很多精神。
使用了两种方法:
方法一:
Tensor以Tfrecords的格式存储数据,如果对数据进行标签,可以同时做到数据打标签。
①创建TFrecords文件
orig_image='/home/images/train_image/'
gen_image='/home/images/image_train.tfrecords'
defcreate_record():
writer=tf.python_io.TFRecordWriter(gen_image)
class_path=orig_image
forimg_nameinos.listdir(class_path):#读取每一幅图像
img_path=class_path+img_name
img=Image.open(img_path)#读取图像
#img=img.resize((256,256))#设置图片大小,在这里可以对图像进行处理
img_raw=img.tobytes()#将图片转化为原声bytes
example=tf.train.Example(
features=tf.train.Features(feature={
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),#打标签
'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))#存储数据
}))
writer.write(example.SerializeToString())
writer.close()
②读取TFrecords文件
defread_and_decode(filename):
#创建文件队列,不限读取的数据
filename_queue=tf.train.string_input_producer([filename])
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)
features=tf.parse_single_example(
serialized_example,
features={
'label':tf.FixedLenFeature([],tf.int64),
'img_raw':tf.FixedLenFeature([],tf.string)})
label=features['label']
img=features['img_raw']
img=tf.decode_raw(img,tf.uint8)#tf.float32
img=tf.image.convert_image_dtype(img,dtype=tf.float32)
img=tf.reshape(img,[256,256,1])
label=tf.cast(label,tf.int32)
returnimg,label
③批量读取数据,使用tf.train.batch
min_after_dequeue=10000 capacity=min_after_dequeue+3*batch_size num_samples=len(os.listdir(orig_image)) create_record() img,label=read_and_decode(gen_image) total_batch=int(num_samples/batch_size) image_batch,label_batch=tf.train.batch([img,label],batch_size=batch_size, num_threads=32,capacity=capacity) init_op=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) withtf.Session()assess: sess.run(init_op) coord=tf.train.Coordinator() threads=tf.train.start_queue_runners(coord=coord) foriinrange(total_batch): cur_image_batch,cur_label_batch=sess.run([image_batch,label_batch]) coord.request_stop() coord.join(threads)
方法二:
使用h5py就是使用数组的格式来存储数据
这个方法比较好,在CNN的过程中,会使用到多个数据类存储,比较好用,比如一个数据进行了两种以上的变化,并且分类存储,我认为这个方法会比较好用。
importos
importh5py
importmatplotlib.pyplotasplt
importnumpyasnp
importrandom
fromscipy.interpolateimportgriddata
fromskimageimportimg_as_float
importmatplotlib.pyplotasplt
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
class_path='/home/awen/Juanjuan/PythonProject/train_BSDS/test_gray_0_1/'
forimg_nameinos.listdir(class_path):
img_path=class_path+img_name
img=io.imread(img_path)
m1=img_as_float(img)
m2,m3=sample_inter1(m1)#一个数据处理的函数
m1=m1.reshape([256,256,1])
m2=m2.reshape([256,256,1])
m3=m3.reshape([256,256,1])
orig_image.append(m1)
sample_near.append(m2)
sample_line.append(m3)
arrorig_image=np.asarray(orig_image)#[?,256,256,1]
arrlsample_near=np.asarray(sample_near)#[?,256,256,1]
arrlsample_line=np.asarray(sample_line)#[?,256,256,1]
save_path='/home/awen/Juanjuan/PythonProject/train_BSDS/test_sample/train.h5'
defmake_data(path):
withh5py.File(save_path,'w')ashf:
hf.create_dataset('orig_image',data=arrorig_image)
hf.create_dataset('sample_near',data=arrlsample_near)
hf.create_dataset('sample_line',data=arrlsample_line)
defread_data(path):
withh5py.File(path,'r')ashf:
orig_image=np.array(hf.get('orig_image'))#一定要对清楚上边的标签名orig_image;
sample_near=np.array(hf.get('sample_near'))
sample_line=np.array(hf.get('sample_line'))
returnorig_image,sample_near,sample_line
make_data(save_path)
orig_image1,sample_near1,sample_line1=read_data(save_path)
total_number=len(orig_image1)
batch_size=20
batch_index=total_number/batch_size
foriinrange(batch_index):
batch_orig=orig_image1[i*batch_size:(i+1)*batch_size]
batch_sample_near=sample_near1[i*batch_size:(i+1)*batch_size]
batch_sample_line=sample_line1[i*batch_size:(i+1)*batch_size]
在使用h5py的时候,生成的文件巨大的时候,读取数据显示错误:ioerror:unabletoopenfile(badobjectheaderversionnumber)
基本就是这个生成的文件不能使用,适当的减少存储的数据,即可。
以上这篇基于Tensorflow批量数据的输入实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。