使用TensorFlow-Slim进行图像分类的实现
参考https://github.com/tensorflow/models/tree/master/slim
使用TensorFlow-Slim进行图像分类
准备
安装TensorFlow
参考https://www.tensorflow.org/install/
如在Ubuntu下安装TensorFlowwithGPUsupport,python2.7版本
wgethttps://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl pipinstalltensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
下载TF-slim图像模型库
cd$WORKSPACE gitclonehttps://github.com/tensorflow/models/
准备数据
有不少公开数据集,这里以官网提供的Flowers为例。
官网提供了下载和转换数据的代码,为了理解代码并能使用自己的数据,这里参考官方提供的代码进行修改。
cd$WORKSPACE/data wgethttp://download.tensorflow.org/example_images/flower_photos.tgz tarzxfflower_photos.tgz
数据集文件夹结构如下:
flower_photos ├──daisy │├──100080576_f52e8ee070_n.jpg │└──... ├──dandelion ├──LICENSE.txt ├──roses ├──sunflowers └──tulips
由于实际情况中我们自己的数据集并不一定把图片按类别放在不同的文件夹里,故我们生成list.txt来表示图片路径与标签的关系。
Python代码:
importos class_names_to_ids={'daisy':0,'dandelion':1,'roses':2,'sunflowers':3,'tulips':4} data_dir='flower_photos/' output_path='list.txt' fd=open(output_path,'w') forclass_nameinclass_names_to_ids.keys(): images_list=os.listdir(data_dir+class_name) forimage_nameinimages_list: fd.write('{}/{}{}\n'.format(class_name,image_name,class_names_to_ids[class_name])) fd.close()
为了方便后期查看label标签,也可以定义labels.txt:
daisy dandelion roses sunflowers tulips
随机生成训练集与验证集:
Python代码:
importrandom _NUM_VALIDATION=350 _RANDOM_SEED=0 list_path='list.txt' train_list_path='list_train.txt' val_list_path='list_val.txt' fd=open(list_path) lines=fd.readlines() fd.close() random.seed(_RANDOM_SEED) random.shuffle(lines) fd=open(train_list_path,'w') forlineinlines[_NUM_VALIDATION:]: fd.write(line) fd.close() fd=open(val_list_path,'w') forlineinlines[:_NUM_VALIDATION]: fd.write(line) fd.close()
生成TFRecord数据:
Python代码:
importsys sys.path.insert(0,'../models/slim/') fromdatasetsimportdataset_utils importmath importos importtensorflowastf defconvert_dataset(list_path,data_dir,output_dir,_NUM_SHARDS=5): fd=open(list_path) lines=[line.split()forlineinfd] fd.close() num_per_shard=int(math.ceil(len(lines)/float(_NUM_SHARDS))) withtf.Graph().as_default(): decode_jpeg_data=tf.placeholder(dtype=tf.string) decode_jpeg=tf.image.decode_jpeg(decode_jpeg_data,channels=3) withtf.Session('')assess: forshard_idinrange(_NUM_SHARDS): output_path=os.path.join(output_dir, 'data_{:05}-of-{:05}.tfrecord'.format(shard_id,_NUM_SHARDS)) tfrecord_writer=tf.python_io.TFRecordWriter(output_path) start_ndx=shard_id*num_per_shard end_ndx=min((shard_id+1)*num_per_shard,len(lines)) foriinrange(start_ndx,end_ndx): sys.stdout.write('\r>>Convertingimage{}/{}shard{}'.format( i+1,len(lines),shard_id)) sys.stdout.flush() image_data=tf.gfile.FastGFile(os.path.join(data_dir,lines[i][0]),'rb').read() image=sess.run(decode_jpeg,feed_dict={decode_jpeg_data:image_data}) height,width=image.shape[0],image.shape[1] example=dataset_utils.image_to_tfexample( image_data,b'jpg',height,width,int(lines[i][1])) tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.close() sys.stdout.write('\n') sys.stdout.flush() os.system('mkdir-ptrain') convert_dataset('list_train.txt','flower_photos','train/') os.system('mkdir-pval') convert_dataset('list_val.txt','flower_photos','val/')
得到的文件夹结构如下:
data ├──flower_photos ├──labels.txt ├──list_train.txt ├──list.txt ├──list_val.txt ├──train │├──data_00000-of-00005.tfrecord │├──... │└──data_00004-of-00005.tfrecord └──val ├──data_00000-of-00005.tfrecord ├──... └──data_00004-of-00005.tfrecord
(可选)下载模型
官方提供了不少预训练模型,这里以Inception-ResNet-v2以例。
cd$WORKSPACE/checkpoints wgethttp://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz tarzxfinception_resnet_v2_2016_08_30.tar.gz
训练
读入数据
官方提供了读入Flowers数据集的代码models/slim/datasets/flowers.py,同样这里也是参考并修改成能读入上面定义的通用数据集。
把下面代码写入models/slim/datasets/dataset_classification.py。
importos importtensorflowastf slim=tf.contrib.slim defget_dataset(dataset_dir,num_samples,num_classes,labels_to_names_path=None,file_pattern='*.tfrecord'): file_pattern=os.path.join(dataset_dir,file_pattern) keys_to_features={ 'image/encoded':tf.FixedLenFeature((),tf.string,default_value=''), 'image/format':tf.FixedLenFeature((),tf.string,default_value='png'), 'image/class/label':tf.FixedLenFeature( [],tf.int64,default_value=tf.zeros([],dtype=tf.int64)), } items_to_handlers={ 'image':slim.tfexample_decoder.Image(), 'label':slim.tfexample_decoder.Tensor('image/class/label'), } decoder=slim.tfexample_decoder.TFExampleDecoder(keys_to_features,items_to_handlers) items_to_descriptions={ 'image':'Acolorimageofvaryingsize.', 'label':'Asingleintegerbetween0and'+str(num_classes-1), } labels_to_names=None iflabels_to_names_pathisnotNone: fd=open(labels_to_names_path) labels_to_names={i:line.strip()fori,lineinenumerate(fd)} fd.close() returnslim.dataset.Dataset( data_sources=file_pattern, reader=tf.TFRecordReader, decoder=decoder, num_samples=num_samples, items_to_descriptions=items_to_descriptions, num_classes=num_classes, labels_to_names=labels_to_names)
构建模型
官方提供了许多模型在models/slim/nets/。
如需要自定义模型,则参考官方提供的模型并放在对应的文件夹即可。
开始训练
官方提供了训练脚本,如果使用官方的数据读入和处理,可使用以下方式开始训练。
cd$WORKSPACE/models/slim CUDA_VISIBLE_DEVICES="0"pythontrain_image_classifier.py\ --train_dir=train_logs\ --dataset_name=flowers\ --dataset_split_name=train\ --dataset_dir=../../data/flowers\ --model_name=inception_resnet_v2\ --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt\ --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits\ --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits\ --max_number_of_steps=1000\ --batch_size=32\ --learning_rate=0.01\ --learning_rate_decay_type=fixed\ --save_interval_secs=60\ --save_summaries_secs=60\ --log_every_n_steps=10\ --optimizer=rmsprop\ --weight_decay=0.00004
不fine-tune把--checkpoint_path,--checkpoint_exclude_scopes和--trainable_scopes删掉。
fine-tune所有层把--checkpoint_exclude_scopes和--trainable_scopes删掉。
如果只使用CPU则加上--clone_on_cpu=True。
其它参数可删掉用默认值或自行修改。
使用自己的数据则需要修改models/slim/train_image_classifier.py:
把
fromdatasetsimportdataset_factory
修改为
fromdatasetsimportdataset_classification
把
dataset=dataset_factory.get_dataset( FLAGS.dataset_name,FLAGS.dataset_split_name,FLAGS.dataset_dir)
修改为
dataset=dataset_classification.get_dataset( FLAGS.dataset_dir,FLAGS.num_samples,FLAGS.num_classes,FLAGS.labels_to_names_path)
在
tf.app.flags.DEFINE_string( 'dataset_dir',None,'Thedirectorywherethedatasetfilesarestored.')
后加入
tf.app.flags.DEFINE_integer( 'num_samples',3320,'Numberofsamples.') tf.app.flags.DEFINE_integer( 'num_classes',5,'Numberofclasses.') tf.app.flags.DEFINE_string( 'labels_to_names_path',None,'Labelnamesfilepath.')
训练时执行以下命令即可:
cd$WORKSPACE/models/slim pythontrain_image_classifier.py\ --train_dir=train_logs\ --dataset_dir=../../data/train\ --num_samples=3320\ --num_classes=5\ --labels_to_names_path=../../data/labels.txt\ --model_name=inception_resnet_v2\ --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt\ --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits\ --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits
可视化log
可一边训练一边可视化训练的log,可看到Loss趋势。
tensorboard--logdirtrain_logs/
验证
官方提供了验证脚本。
pythoneval_image_classifier.py\ --checkpoint_path=train_logs\ --eval_dir=eval_logs\ --dataset_name=flowers\ --dataset_split_name=validation\ --dataset_dir=../../data/flowers\ --model_name=inception_resnet_v2
同样,如果是使用自己的数据集,则需要修改models/slim/eval_image_classifier.py:
把
fromdatasetsimportdataset_factory
修改为
fromdatasetsimportdataset_classification
把
dataset=dataset_factory.get_dataset( FLAGS.dataset_name,FLAGS.dataset_split_name,FLAGS.dataset_dir)
修改为
dataset=dataset_classification.get_dataset( FLAGS.dataset_dir,FLAGS.num_samples,FLAGS.num_classes,FLAGS.labels_to_names_path)
在
tf.app.flags.DEFINE_string( 'dataset_dir',None,'Thedirectorywherethedatasetfilesarestored.')
后加入
tf.app.flags.DEFINE_integer( 'num_samples',350,'Numberofsamples.') tf.app.flags.DEFINE_integer( 'num_classes',5,'Numberofclasses.') tf.app.flags.DEFINE_string( 'labels_to_names_path',None,'Labelnamesfilepath.')
验证时执行以下命令即可:
pythoneval_image_classifier.py\ --checkpoint_path=train_logs\ --eval_dir=eval_logs\ --dataset_dir=../../data/val\ --num_samples=350\ --num_classes=5\ --model_name=inception_resnet_v2
可以一边训练一边验证,,注意使用其它的GPU或合理分配显存。
同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:
tensorboard--logdireval_logs/--port6007
测试
参考models/slim/eval_image_classifier.py,可编写读取图片用模型进行推导的脚本models/slim/test_image_classifier.py
from__future__importabsolute_import from__future__importdivision from__future__importprint_function importos importmath importtensorflowastf fromnetsimportnets_factory frompreprocessingimportpreprocessing_factory slim=tf.contrib.slim tf.app.flags.DEFINE_string( 'master','','TheaddressoftheTensorFlowmastertouse.') tf.app.flags.DEFINE_string( 'checkpoint_path','/tmp/tfmodel/', 'Thedirectorywherethemodelwaswrittentooranabsolutepathtoa' 'checkpointfile.') tf.app.flags.DEFINE_string( 'test_path','','Testimagepath.') tf.app.flags.DEFINE_integer( 'num_classes',5,'Numberofclasses.') tf.app.flags.DEFINE_integer( 'labels_offset',0, 'Anoffsetforthelabelsinthedataset.Thisflagisprimarilyusedto' 'evaluatetheVGGandResNetarchitectureswhichdonotuseabackground' 'classfortheImageNetdataset.') tf.app.flags.DEFINE_string( 'model_name','inception_v3','Thenameofthearchitecturetoevaluate.') tf.app.flags.DEFINE_string( 'preprocessing_name',None,'Thenameofthepreprocessingtouse.Ifleft' 'as`None`,thenthemodel_nameflagisused.') tf.app.flags.DEFINE_integer( 'test_image_size',None,'Evalimagesize') FLAGS=tf.app.flags.FLAGS defmain(_): ifnotFLAGS.test_list: raiseValueError('Youmustsupplythetestlistwith--test_list') tf.logging.set_verbosity(tf.logging.INFO) withtf.Graph().as_default(): tf_global_step=slim.get_or_create_global_step() #################### #Selectthemodel# #################### network_fn=nets_factory.get_network_fn( FLAGS.model_name, num_classes=(FLAGS.num_classes-FLAGS.labels_offset), is_training=False) ##################################### #Selectthepreprocessingfunction# ##################################### preprocessing_name=FLAGS.preprocessing_nameorFLAGS.model_name image_preprocessing_fn=preprocessing_factory.get_preprocessing( preprocessing_name, is_training=False) test_image_size=FLAGS.test_image_sizeornetwork_fn.default_image_size iftf.gfile.IsDirectory(FLAGS.checkpoint_path): checkpoint_path=tf.train.latest_checkpoint(FLAGS.checkpoint_path) else: checkpoint_path=FLAGS.checkpoint_path tf.Graph().as_default() withtf.Session()assess: image=open(FLAGS.test_path,'rb').read() image=tf.image.decode_jpeg(image,channels=3) processed_image=image_preprocessing_fn(image,test_image_size,test_image_size) processed_images=tf.expand_dims(processed_image,0) logits,_=network_fn(processed_images) predictions=tf.argmax(logits,1) saver=tf.train.Saver() saver.restore(sess,checkpoint_path) np_image,network_input,predictions=sess.run([image,processed_image,predictions]) print('{}{}'.format(FLAGS.test_path,predictions[0])) if__name__=='__main__': tf.app.run()
测试时执行以下命令即可:
pythontest_image_classifier.py\ --checkpoint_path=train_logs/\ --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg\ --num_classes=5\ --model_name=inception_resnet_v2
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。