解决Tensorflow sess.run导致的内存溢出问题
下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf,PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。
#CreatesgraphfromsavedGraphDef create_graph(pb_path) #InittfSession config=tf.ConfigProto() config.gpu_options.allow_growth=True sess=tf.Session(config=config) init=tf.global_variables_initializer() sess.run(init) input_image_tensor=sess.graph.get_tensor_by_name("create_inputs/batch:0") output_tensor_name=sess.graph.get_tensor_by_name("conv6/out_1:0") forfilenameinos.listdir(image_dir): image_path=os.path.join(image_dir,filename) start=time.time() image_data=cv2.imread(image_path) image_data=cv2.resize(image_data,(w,h)) image_data_1=image_data-IMG_MEAN input_image=np.expand_dims(image_data_1,0) raw_output_up=tf.image.resize_bilinear(output_tensor_name,size=[h,w],align_corners=True) raw_output_up=tf.argmax(raw_output_up,axis=3) predict_img=sess.run(raw_output_up,feed_dict={input_image_tensor:input_image})#1,height,width predict_img=np.squeeze(predict_img)#height,width voc_palette=visual.make_palette(3) masked_im=visual.vis_seg(image_data,predict_img,voc_palette) cv2.imwrite("%s_pred.png"%(save_dir+filename.split(".")[0]),masked_im) print(time.time()-start) print(">>>>>>Done")
下面是解决溢出问题的代码(将部分代码放在for循环外)
#CreatesgraphfromsavedGraphDef create_graph(pb_path) #InittfSession config=tf.ConfigProto() config.gpu_options.allow_growth=True sess=tf.Session(config=config) init=tf.global_variables_initializer() sess.run(init) input_image_tensor=sess.graph.get_tensor_by_name("create_inputs/batch:0") output_tensor_name=sess.graph.get_tensor_by_name("conv6/out_1:0") ############################################################################################################## raw_output_up=tf.image.resize_bilinear(output_tensor_name,size=[h,w],align_corners=True) raw_output_up=tf.argmax(raw_output_up,axis=3) ############################################################################################################## forfilenameinos.listdir(image_dir): image_path=os.path.join(image_dir,filename) start=time.time() image_data=cv2.imread(image_path) image_data=cv2.resize(image_data,(w,h)) image_data_1=image_data-IMG_MEAN input_image=np.expand_dims(image_data_1,0) predict_img=sess.run(raw_output_up,feed_dict={input_image_tensor:input_image})#1,height,width predict_img=np.squeeze(predict_img)#height,width voc_palette=visual.make_palette(3) masked_im=visual.vis_seg(image_data,predict_img,voc_palette) cv2.imwrite("%s_pred.png"%(save_dir+filename.split(".")[0]),masked_im) print(time.time()-start) print(">>>>>>Done")
总结:
在迭代过程中,在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;
建议不要使用tf.gfile.FastGFile(image_path,'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。
以上这篇解决Tensoflowsess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。