初探TensorFLow从文件读取图片的四种方式
本文记录一下TensorFLow的几种图片读取方法,官方文档有较为全面的介绍。
1.使用gfile读图片,decode输出是Tensor,eval后是ndarray
importmatplotlib.pyplotasplt
importtensorflowastf
importnumpyasnp
print(tf.__version__)
image_raw=tf.gfile.FastGFile('test/a.jpg','rb').read()#bytes
img=tf.image.decode_jpeg(image_raw)#Tensor
#img2=tf.image.convert_image_dtype(img,dtype=tf.uint8)
withtf.Session()assess:
print(type(image_raw))#bytes
print(type(img))#Tensor
#print(type(img2))
print(type(img.eval()))#ndarray!!!
print(img.eval().shape)
print(img.eval().dtype)
#print(type(img2.eval()))
#print(img2.eval().shape)
#print(img2.eval().dtype)
plt.figure(1)
plt.imshow(img.eval())
plt.show()
输出为:
1.3.0
(666,1000,3)
uint8
图片显示(略)
2.使用WholeFileReader输入queue,decode输出是Tensor,eval后是ndarray
importtensorflowastf
importos
importmatplotlib.pyplotasplt
deffile_name(file_dir):#来自https://www.nhooo.com/article/134543.htm
forroot,dirs,filesinos.walk(file_dir):#模块os中的walk()函数遍历文件夹下所有的文件
print(root)#当前目录路径
print(dirs)#当前路径下所有子目录
print(files)#当前路径下所有非目录子文件
deffile_name2(file_dir):#特定类型的文件
L=[]
forroot,dirs,filesinos.walk(file_dir):
forfileinfiles:
ifos.path.splitext(file)[1]=='.jpg':
L.append(os.path.join(root,file))
returnL
path=file_name2('test')
#以下参考https://www.nhooo.com/article/134547.htm(十图详解TensorFlow数据读取机制)
#path2=tf.train.match_filenames_once(path)
file_queue=tf.train.string_input_producer(path,shuffle=True,num_epochs=2)#创建输入队列
image_reader=tf.WholeFileReader()
key,image=image_reader.read(file_queue)
image=tf.image.decode_jpeg(image)
withtf.Session()assess:
#coord=tf.train.Coordinator()#协同启动的线程
#threads=tf.train.start_queue_runners(sess=sess,coord=coord)#启动线程运行队列
#coord.request_stop()#停止所有的线程
#coord.join(threads)
tf.local_variables_initializer().run()
threads=tf.train.start_queue_runners(sess=sess)
#print(type(image))
#print(type(image.eval()))
#print(image.eval().shape)
for_inpath+path:
plt.figure
plt.imshow(image.eval())
plt.show()
3.使用read_file,decode输出是Tensor,eval后是ndarray
importmatplotlib.pyplotasplt
importtensorflowastf
importnumpyasnp
print(tf.__version__)
image_value=tf.read_file('test/a.jpg')
img=tf.image.decode_jpeg(image_value,channels=3)
withtf.Session()assess:
print(type(image_value))#bytes
print(type(img))#Tensor
#print(type(img2))
print(type(img.eval()))#ndarray!!!
print(img.eval().shape)
print(img.eval().dtype)
#print(type(img2.eval()))
#print(img2.eval().shape)
#print(img2.eval().dtype)
plt.figure(1)
plt.imshow(img.eval())
plt.show()
输出是:
1.3.0
(666,1000,3)
uint8
显示图片(略)
4.TFRecords:
有空再看。
如果图片是根据分类放在不同的文件夹下,那么可以直接使用如下代码:
https://www.nhooo.com/article/134532.htm
https://www.nhooo.com/article/134539.htm
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
