30秒轻松实现TensorFlow物体检测
Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyternotebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。
使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。
首先我们载入一些会使用的库
importnumpyasnp importos importsix.moves.urllibasurllib importsys importtarfile importtensorflowastf importzipfile fromcollectionsimportdefaultdict fromioimportStringIO frommatplotlibimportpyplotasplt fromPILimportImage
接下来进行环境设置
%matplotlibinline sys.path.append("..")
物体检测载入
fromutilsimportlabel_map_util fromutilsimportvisualization_utilsasvis_util
准备模型
变量 任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。
MODEL_NAME='ssd_mobilenet_v1_coco_11_06_2017' MODEL_FILE=MODEL_NAME+'.tar.gz' DOWNLOAD_BASE='http://download.tensorflow.org/models/object_detection/' PATH_TO_CKPT=MODEL_NAME+'/frozen_inference_graph.pb' PATH_TO_LABELS=os.path.join('data','mscoco_label_map.pbtxt') NUM_CLASSES=90
下载模型
opener=urllib.request.URLopener() opener.retrieve(DOWNLOAD_BASE+MODEL_FILE,MODEL_FILE) tar_file=tarfile.open(MODEL_FILE) forfileintar_file.getmembers(): file_name=os.path.basename(file.name) if'frozen_inference_graph.pb'infile_name: tar_file.extract(file,os.getcwd())
将(frozen)TensorFlow模型载入内存
detection_graph=tf.Graph() withdetection_graph.as_default(): od_graph_def=tf.GraphDef() withtf.gfile.GFile(PATH_TO_CKPT,'rb')asfid: serialized_graph=fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def,name='')
载入标签图
标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。
label_map=label_map_util.load_labelmap(PATH_TO_LABELS) categories=label_map_util.convert_label_map_to_categories(label_map,max_num_classes=NUM_CLASSES,use_display_name=True) category_index=label_map_util.create_category_index(categories)
辅助代码
defload_image_into_numpy_array(image): (im_width,im_height)=image.size returnnp.array(image.getdata()).reshape( (im_height,im_width,3)).astype(np.uint8)
检测
PATH_TO_TEST_IMAGES_DIR='test_images' TEST_IMAGE_PATHS=[os.path.join(PATH_TO_TEST_IMAGES_DIR,'image{}.jpg'.format(i))foriinrange(1,3)] IMAGE_SIZE=(12,8) [python]viewplaincopy withdetection_graph.as_default(): withtf.Session(graph=detection_graph)assess: forimage_pathinTEST_IMAGE_PATHS: image=Image.open(image_path) #这个array在之后会被用来准备为图片加上框和标签 image_np=load_image_into_numpy_array(image) #扩展维度,应为模型期待:[1,None,None,3] image_np_expanded=np.expand_dims(image_np,axis=0) image_tensor=detection_graph.get_tensor_by_name('image_tensor:0') #每个框代表一个物体被侦测到. boxes=detection_graph.get_tensor_by_name('detection_boxes:0') #每个分值代表侦测到物体的可信度. scores=detection_graph.get_tensor_by_name('detection_scores:0') classes=detection_graph.get_tensor_by_name('detection_classes:0') num_detections=detection_graph.get_tensor_by_name('num_detections:0') #执行侦测任务. (boxes,scores,classes,num_detections)=sess.run( [boxes,scores,classes,num_detections], feed_dict={image_tensor:image_np_expanded}) #图形化. vis_util.visualize_boxes_and_labels_on_image_array( image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) plt.figure(figsize=IMAGE_SIZE) plt.imshow(image_np)
在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。