基于Tensorflow批量数据的输入实现方式
基于Tensorflow下的批量数据的输入处理:
1.TensorTFrecords格式
2.h5py的库的数组方法
在tensorflow的框架下写CNN代码,我在书写过程中,感觉不是框架内容难写,更多的是我在对图像的预处理和输入这部分花了很多精神。
使用了两种方法:
方法一:
Tensor以Tfrecords的格式存储数据,如果对数据进行标签,可以同时做到数据打标签。
①创建TFrecords文件
orig_image='/home/images/train_image/' gen_image='/home/images/image_train.tfrecords' defcreate_record(): writer=tf.python_io.TFRecordWriter(gen_image) class_path=orig_image forimg_nameinos.listdir(class_path):#读取每一幅图像 img_path=class_path+img_name img=Image.open(img_path)#读取图像 #img=img.resize((256,256))#设置图片大小,在这里可以对图像进行处理 img_raw=img.tobytes()#将图片转化为原声bytes example=tf.train.Example( features=tf.train.Features(feature={ 'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),#打标签 'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))#存储数据 })) writer.write(example.SerializeToString()) writer.close()
②读取TFrecords文件
defread_and_decode(filename): #创建文件队列,不限读取的数据 filename_queue=tf.train.string_input_producer([filename]) reader=tf.TFRecordReader() _,serialized_example=reader.read(filename_queue) features=tf.parse_single_example( serialized_example, features={ 'label':tf.FixedLenFeature([],tf.int64), 'img_raw':tf.FixedLenFeature([],tf.string)}) label=features['label'] img=features['img_raw'] img=tf.decode_raw(img,tf.uint8)#tf.float32 img=tf.image.convert_image_dtype(img,dtype=tf.float32) img=tf.reshape(img,[256,256,1]) label=tf.cast(label,tf.int32) returnimg,label
③批量读取数据,使用tf.train.batch
min_after_dequeue=10000 capacity=min_after_dequeue+3*batch_size num_samples=len(os.listdir(orig_image)) create_record() img,label=read_and_decode(gen_image) total_batch=int(num_samples/batch_size) image_batch,label_batch=tf.train.batch([img,label],batch_size=batch_size, num_threads=32,capacity=capacity) init_op=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) withtf.Session()assess: sess.run(init_op) coord=tf.train.Coordinator() threads=tf.train.start_queue_runners(coord=coord) foriinrange(total_batch): cur_image_batch,cur_label_batch=sess.run([image_batch,label_batch]) coord.request_stop() coord.join(threads)
方法二:
使用h5py就是使用数组的格式来存储数据
这个方法比较好,在CNN的过程中,会使用到多个数据类存储,比较好用,比如一个数据进行了两种以上的变化,并且分类存储,我认为这个方法会比较好用。
importos importh5py importmatplotlib.pyplotasplt importnumpyasnp importrandom fromscipy.interpolateimportgriddata fromskimageimportimg_as_float importmatplotlib.pyplotasplt os.environ['TF_CPP_MIN_LOG_LEVEL']='2' class_path='/home/awen/Juanjuan/PythonProject/train_BSDS/test_gray_0_1/' forimg_nameinos.listdir(class_path): img_path=class_path+img_name img=io.imread(img_path) m1=img_as_float(img) m2,m3=sample_inter1(m1)#一个数据处理的函数 m1=m1.reshape([256,256,1]) m2=m2.reshape([256,256,1]) m3=m3.reshape([256,256,1]) orig_image.append(m1) sample_near.append(m2) sample_line.append(m3) arrorig_image=np.asarray(orig_image)#[?,256,256,1] arrlsample_near=np.asarray(sample_near)#[?,256,256,1] arrlsample_line=np.asarray(sample_line)#[?,256,256,1] save_path='/home/awen/Juanjuan/PythonProject/train_BSDS/test_sample/train.h5' defmake_data(path): withh5py.File(save_path,'w')ashf: hf.create_dataset('orig_image',data=arrorig_image) hf.create_dataset('sample_near',data=arrlsample_near) hf.create_dataset('sample_line',data=arrlsample_line) defread_data(path): withh5py.File(path,'r')ashf: orig_image=np.array(hf.get('orig_image'))#一定要对清楚上边的标签名orig_image; sample_near=np.array(hf.get('sample_near')) sample_line=np.array(hf.get('sample_line')) returnorig_image,sample_near,sample_line make_data(save_path) orig_image1,sample_near1,sample_line1=read_data(save_path) total_number=len(orig_image1) batch_size=20 batch_index=total_number/batch_size foriinrange(batch_index): batch_orig=orig_image1[i*batch_size:(i+1)*batch_size] batch_sample_near=sample_near1[i*batch_size:(i+1)*batch_size] batch_sample_line=sample_line1[i*batch_size:(i+1)*batch_size]
在使用h5py的时候,生成的文件巨大的时候,读取数据显示错误:ioerror:unabletoopenfile(badobjectheaderversionnumber)
基本就是这个生成的文件不能使用,适当的减少存储的数据,即可。
以上这篇基于Tensorflow批量数据的输入实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。