Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取
单一数据读取方式:
第一种:slice_input_producer()
#返回值可以直接通过Session.run([images,labels])查看,且第一个参数必须放在列表中,如[...] [images,labels]=tf.train.slice_input_producer([images,labels],num_epochs=None,shuffle=True)
第二种:string_input_producer()
#需要定义文件读取器,然后通过读取器中的read()方法来获取数据(返回值类型key,value),再通过Session.run(value)查看 file_queue=tf.train.string_input_producer(filename,num_epochs=None,shuffle=True) reader=tf.WholeFileReader()#定义文件读取器 key,value=reader.read(file_queue)#key:文件名;value:文件中的内容
!!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。
!!!如果它不是None,则此函数创建本地计数器epochs,需要使用local_variables_initializer()初始化局部变量
!!!以上两种方法都可以生成文件名队列。
(随机)批量数据读取方式:
batchsize=2#每次读取的样本数量 tf.train.batch(tensors,batch_size=batchsize) tf.train.shuffle_batch(tensors,batch_size=batchsize,capacity=batchsize*10,min_after_dequeue=batchsize*5)#capacity>min_after_dequeue
!!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程tf.train.start_queue_runners()
TFRecord文件的打包与读取
一、单一数据读取方式
第一种:slice_input_producer()
defslice_input_producer(tensor_list,num_epochs=None,shuffle=True,seed=None,capacity=32,shared_name=None,name=None)
案例1:
importtensorflowastf images=['image1.jpg','image2.jpg','image3.jpg','image4.jpg'] labels=[1,2,3,4] #[images,labels]=tf.train.slice_input_producer([images,labels],num_epochs=None,shuffle=True) #当num_epochs=2时,此时文件队列中只有2*4=8个样本,所有在取第9个样本时会出错 #[images,labels]=tf.train.slice_input_producer([images,labels],num_epochs=2,shuffle=True) data=tf.train.slice_input_producer([images,labels],num_epochs=None,shuffle=True) print(type(data))#withtf.Session()assess: #sess.run(tf.local_variables_initializer()) sess.run(tf.local_variables_initializer()) coord=tf.train.Coordinator()#线程的协调器 threads=tf.train.start_queue_runners(sess,coord)#开始在图表中收集队列运行器 foriinrange(10): print(sess.run(data)) coord.request_stop() coord.join(threads) """
运行结果:
[b'image2.jpg',2]
[b'image1.jpg',1]
[b'image3.jpg',3]
[b'image4.jpg',4]
[b'image2.jpg',2]
[b'image1.jpg',1]
[b'image3.jpg',3]
[b'image4.jpg',4]
[b'image2.jpg',2]
[b'image3.jpg',3]
"""
!!!slice_input_producer()中的第一个参数需要放在一个列表中,列表中的每个元素可以是List或Tensor,如[images,labels],
!!!num_epochs设置
第二种:string_input_producer()
defstring_input_producer(string_tensor,num_epochs=None,shuffle=True,seed=None,capacity=32,shared_name=None,name=None,cancel_op=None)
文件读取器
不同类型的文件对应不同的文件读取器,我们称为reader对象;
该对象的read方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容;
reader=tf.TextLineReader()###一行一行读取,适用于所有文本文件 reader=tf.TFRecordReader()###AReaderthatoutputstherecordsfromaTFRecordsfile reader=tf.WholeFileReader()###一次读取整个文件,适用图片
案例2:读取csv文件
importtensorflowastf
filename=['data/A.csv','data/B.csv','data/C.csv']
file_queue=tf.train.string_input_producer(filename,shuffle=True,num_epochs=2)#生成文件名队列
reader=tf.WholeFileReader()#定义文件读取器(一次读取整个文件)
#reader=tf.TextLineReader()#定义文件读取器(一行一行的读)
key,value=reader.read(file_queue)#key:文件名;value:文件中的内容
print(type(file_queue))
init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
withtf.Session()assess:
sess.run(init)
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
try:
whilenotcoord.should_stop():
foriinrange(6):
print(sess.run([key,value]))
break
excepttf.errors.OutOfRangeError:
print('readdone')
finally:
coord.request_stop()
coord.join(threads)
"""
reader=tf.WholeFileReader()#定义文件读取器(一次读取整个文件)
运行结果:
[b'data/C.csv',b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
[b'data/B.csv',b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/A.csv',b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/A.csv',b'1.jpg,1\n2.jpg,2\n3.jpg,3\n']
[b'data/B.csv',b'4.jpg,4\n5.jpg,5\n6.jpg,6\n']
[b'data/C.csv',b'7.jpg,7\n8.jpg,8\n9.jpg,9\n']
"""
"""
reader=tf.TextLineReader()#定义文件读取器(一行一行的读)
运行结果:
[b'data/B.csv:1',b'4.jpg,4']
[b'data/B.csv:2',b'5.jpg,5']
[b'data/B.csv:3',b'6.jpg,6']
[b'data/C.csv:1',b'7.jpg,7']
[b'data/C.csv:2',b'8.jpg,8']
[b'data/C.csv:3',b'9.jpg,9']
"""
案例3:读取图片(每次读取全部图片内容,不是一行一行)
importtensorflowastf
filename=['1.jpg','2.jpg']
filename_queue=tf.train.string_input_producer(filename,shuffle=False,num_epochs=1)
reader=tf.WholeFileReader()#文件读取器
key,value=reader.read(filename_queue)#读取文件key:文件名;value:图片数据,bytes
withtf.Session()assess:
tf.local_variables_initializer().run()
coord=tf.train.Coordinator()#线程的协调器
threads=tf.train.start_queue_runners(sess,coord)
foriinrange(filename.__len__()):
image_data=sess.run(value)
withopen('img_%d.jpg'%i,'wb')asf:
f.write(image_data)
coord.request_stop()
coord.join(threads)
二、(随机)批量数据读取方式:
功能:shuffle_batch()和batch()这两个API都是从文件队列中批量获取数据,使用方式类似;
案例4:slice_input_producer()与batch()
importtensorflowastf
importnumpyasnp
images=np.arange(20).reshape([10,2])
label=np.asarray(range(0,10))
images=tf.cast(images,tf.float32)#可以注释掉,不影响运行结果
label=tf.cast(label,tf.int32)#可以注释掉,不影响运行结果
batchsize=6#每次获取元素的数量
input_queue=tf.train.slice_input_producer([images,label],num_epochs=None,shuffle=False)
image_batch,label_batch=tf.train.batch(input_queue,batch_size=batchsize)
#随机获取batchsize个元素,其中,capacity:队列容量,这个参数一定要比min_after_dequeue大
#image_batch,label_batch=tf.train.shuffle_batch(input_queue,batch_size=batchsize,capacity=64,min_after_dequeue=10)
withtf.Session()assess:
coord=tf.train.Coordinator()#线程的协调器
threads=tf.train.start_queue_runners(sess,coord)#开始在图表中收集队列运行器
forcntinrange(2):
print("第{}次获取数据,每次batch={}...".format(cnt+1,batchsize))
image_batch_v,label_batch_v=sess.run([image_batch,label_batch])
print(image_batch_v,label_batch_v,label_batch_v.__len__())
coord.request_stop()
coord.join(threads)
"""
运行结果:
第1次获取数据,每次batch=6...
[[0. 1.]
[2. 3.]
[4. 5.]
[6. 7.]
[8. 9.]
[10.11.]][012345]6
第2次获取数据,每次batch=6...
[[12.13.]
[14.15.]
[16.17.]
[18.19.]
[0. 1.]
[2. 3.]][678901]6
"""
案例5:从本地批量的读取图片---string_input_producer()与batch()
importtensorflowastf
importglob
importcv2ascv
defread_imgs(filename,picture_format,input_image_shape,batch_size=):
"""
从本地批量的读取图片
:paramfilename:图片路径(包括图片的文件名),[]
:parampicture_format:图片的格式,如bmp,jpg,png等;string
:paraminput_image_shape:输入图像的大小;(h,w,c)或[]
:parambatch_size:每次从文件队列中加载图片的数量;int
:return:batch_size张图片数据,Tensor
"""
globalnew_img
#创建文件队列
file_queue=tf.train.string_input_producer(filename,num_epochs=1,shuffle=True)
#创建文件读取器
reader=tf.WholeFileReader()
#读取文件队列中的文件
_,img_bytes=reader.read(file_queue)
#print(img_bytes)#Tensor("ReaderReadV2_19:1",shape=(),dtype=string)
#对图片进行解码
ifpicture_format==".bmp":
new_img=tf.image.decode_bmp(img_bytes,channels=1)
elifpicture_format==".jpg":
new_img=tf.image.decode_jpeg(img_bytes,channels=3)
else:
pass
#重新设置图片的大小
#new_img=tf.image.resize_images(new_img,input_image_shape)
new_img=tf.reshape(new_img,input_image_shape)
#设置图片的数据类型
new_img=tf.image.convert_image_dtype(new_img,tf.uint)
#returnnew_img
returntf.train.batch([new_img],batch_size)
defmain():
image_path=glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
image_batch=read_imgs(image_path,".bmp",(112,92,1),5)
print(type(image_batch))
#image_path=glob.glob(r'.\*.jpg')
#image_batch=read_imgs(image_path,".jpg",(313,500,3),1)
sess=tf.Session()
sess.run(tf.local_variables_initializer())
tf.train.start_queue_runners(sess=sess)
image_batch=sess.run(image_batch)
print(type(image_batch))#
foriinrange(image_batch.__len__()):
cv.imshow("win_"+str(i),image_batch[i])
cv.waitKey()
cv.destroyAllWindows()
defstart():
image_path=glob.glob(r'F:\demo\FaceRecognition\人脸库\ORL\*.bmp')
image_batch=read_imgs(image_path,".bmp",(112,92,1),5)
print(type(image_batch))#
withtf.Session()assess:
sess.run(tf.local_variables_initializer())
coord=tf.train.Coordinator()#线程的协调器
threads=tf.train.start_queue_runners(sess,coord)#开始在图表中收集队列运行器
image_batch=sess.run(image_batch)
print(type(image_batch))#
foriinrange(image_batch.__len__()):
cv.imshow("win_"+str(i),image_batch[i])
cv.waitKey()
cv.destroyAllWindows()
#若使用with方式打开Session,且没加如下行语句,则会出错
#ERROR:tensorflow:ExceptioninQueueRunner:Enqueueoperationwascancelled;
#原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程
coord.request_stop()
coord.join(threads)
if__name__=="__main__":
#main()
start()
案列6:TFRecord文件打包与读取
TFRecord文件打包案列
defwrite_TFRecord(filename,data,labels,is_shuffler=True):
"""
将数据打包成TFRecord格式
:paramfilename:打包后路径名,默认在工程目录下创建该文件;String
:paramdata:需要打包的文件路径名;list
:paramlabels:对应文件的标签;list
:paramis_shuffler:是否随机初始化打包后的数据,默认:True;Bool
:return:None
"""
im_data=list(data)
im_labels=list(labels)
index=[iforiinrange(im_data.__len__())]
ifis_shuffler:
np.random.shuffle(index)
#创建写入器,然后使用该对象写入样本example
writer=tf.python_io.TFRecordWriter(filename)
foriinrange(im_data.__len__()):
im_d=im_data[index[i]]#im_d:存放着第index[i]张图片的路径信息
im_l=im_labels[index[i]]#im_l:存放着对应图片的标签信息
##获取当前的图片数据方式一:
#data=cv2.imread(im_d)
##创建样本
#ex=tf.train.Example(
#features=tf.train.Features(
#feature={
#"image":tf.train.Feature(
#bytes_list=tf.train.BytesList(
#value=[data.tobytes()])),#需要打包成bytes类型
#"label":tf.train.Feature(
#int64_list=tf.train.Int64List(
#value=[im_l])),
#}
#)
#)
#获取当前的图片数据方式二:相对于方式一,打包文件占用空间小了一半多
data=tf.gfile.FastGFile(im_d,"rb").read()
ex=tf.train.Example(
features=tf.train.Features(
feature={
"image":tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[data])),#此时的data已经是bytes类型
"label":tf.train.Feature(
int_list=tf.train.IntList(
value=[im_l])),
}
)
)
#写入将序列化之后的样本
writer.write(ex.SerializeToString())
#关闭写入器
writer.close()
TFReord文件的读取案列
importtensorflowastf
importcv2
defread_TFRecord(file_list,batch_size=):
"""
读取TFRecord文件
:paramfile_list:存放TFRecord的文件名,List
:parambatch_size:每次读取图片的数量
:return:解析后图片及对应的标签
"""
file_queue=tf.train.string_input_producer(file_list,num_epochs=None,shuffle=True)
reader=tf.TFRecordReader()
_,ex=reader.read(file_queue)
batch=tf.train.shuffle_batch([ex],batch_size,capacity=batch_size*10,min_after_dequeue=batch_size*5)
feature={
'image':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([],tf.int64)
}
example=tf.parse_example(batch,features=feature)
images=tf.decode_raw(example['image'],tf.uint)
images=tf.reshape(images,[-1,32,32,3])
returnimages,example['label']
defmain():
#filelist=['data/train.tfrecord']
filelist=['data/test.tfrecord']
images,labels=read_TFRecord(filelist,2)
withtf.Session()assess:
sess.run(tf.local_variables_initializer())
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
try:
whilenotcoord.should_stop():
foriinrange():
image_bth,_=sess.run([images,labels])
print(_)
cv2.imshow("image_0",image_bth[0])
cv2.imshow("image_1",image_bth[1])
break
excepttf.errors.OutOfRangeError:
print('readdone')
finally:
coord.request_stop()
coord.join(threads)
cv2.waitKey(0)
cv2.destroyAllWindows()
if__name__=="__main__":
main()
到此这篇关于Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取的文章就介绍到这了,更多相关TensorflowTFRecord打包与读取内容请搜索毛票票以前的文章或继续浏览下面的相关文章希望大家以后多多支持毛票票!