pytorch制作自己的LMDB数据操作示例
本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:
前言
记录下pytorch里如何使用lmdb的code,自用
制作部分的Code
code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就是图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签
import os
import lmdb  # install lmdb by "pip install lmdb"
import cv2
import numpy as np
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re


def checkImageIsValid(imageBin):
    """Return True if the binary blob decodes to a non-empty grayscale image."""
    if imageBin is None:
        return False
    # np.frombuffer replaces np.fromstring, which is deprecated and removed
    # from recent NumPy releases.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode returns None for corrupt/unsupported data; the original
    # code crashed with AttributeError on img.shape in that case.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):
    """Flush the key/value cache into the LMDB environment in one write transaction."""
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            # LMDB keys must be bytes; values are expected to be bytes already.
            txn.put(k.encode(), v)


def _is_difficult(word):
    """Return True if *word* contains anything besides word characters."""
    assert isinstance(word, str)
    return not re.match(r'^[\w]+$', word)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image paths
        labelList     : list of corresponding ground-truth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if True, check the validity of every image
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    # map_size is the maximum DB size: 1099511627776 bytes = 1 TiB
    # (the original comment claimed 1048576 GB, which was wrong).
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        # Skip samples with an empty label.
        if len(label) == 0:
            continue
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        # The database stores raw bytes for both keys and values.
        imageKey = 'image-%09d' % cnt  # zero-padded to 9 digits
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            # BUG FIX: txn.put requires bytes values; the original stored a
            # plain str here (image/label were encoded but lexicon was not),
            # which raises TypeError at write time.
            cache[lexiconKey] = ''.join(lexiconList[i]).encode()
        # Flush in batches of 1000 to bound memory usage.
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


def get_sample_list(txt_path: str):
    """Read the image-list file; return (jpg paths, label strings).

    Only keeps images whose sibling .txt label file exists; each label is the
    first line of that file, stripped.
    """
    with open(txt_path, 'r') as fr:
        jpg_list = [x.strip() for x in fr.readlines()
                    if os.path.exists(x.replace('.jpg', '.txt').strip())]
    txt_content_list = []
    for jpg in jpg_list:
        label_path = jpg.replace('.jpg', '.txt')
        with open(label_path, 'r') as fr:
            try:
                str_tmp = fr.readline()
            except UnicodeDecodeError as e:
                # Report which label file is badly encoded before re-raising.
                print(label_path)
                raise e
            txt_content_list.append(str_tmp.strip())
    return jpg_list, txt_content_list


if __name__ == "__main__":
    txt_path = '/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
    imagePathList, labelList = get_sample_list(txt_path)
    createDataset(lmdb_output_path, imagePathList, labelList)
读取部分
这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__
from__future__importabsolute_import #importsys #sys.path.append('./') importos #importmoxingasmox importpickle fromtqdmimporttqdm fromPILimportImage,ImageFile importnumpyasnp importrandom importcv2 importlmdb importsys importsix importtorch fromtorch.utilsimportdata fromtorch.utils.dataimportsampler fromtorchvisionimporttransforms fromlib.utils.labelmapsimportget_vocabulary,labels2strs fromlib.utilsimportto_numpy ImageFile.LOAD_TRUNCATED_IMAGES=True fromconfigimportget_args global_args=get_args(sys.argv[1:]) ifglobal_args.run_on_remote: importmoxingasmox #moxing是一个分布式的框架跳过 classLmdbDataset(data.Dataset): def__init__(self,root,voc_type,max_len,num_samples,transform=None): super(LmdbDataset,self).__init__() ifglobal_args.run_on_remote: dataset_name=os.path.basename(root) data_cache_url="/cache/%s"%dataset_name ifnotos.path.exists(data_cache_url): os.makedirs(data_cache_url) ifmox.file.exists(root): mox.file.copy_parallel(root,data_cache_url) else: raiseValueError("%snotexists!"%root) self.env=lmdb.open(data_cache_url,max_readers=32,readonly=True) else: self.env=lmdb.open(root,max_readers=32,readonly=True) assertself.envisnotNone,"cannotcreatelmdbfrom%s"%root self.txn=self.env.begin() self.voc_type=voc_type self.transform=transform self.max_len=max_len self.nSamples=int(self.txn.get(b"num-samples")) self.nSamples=min(self.nSamples,num_samples) assertvoc_typein['LOWERCASE','ALLCASES','ALLCASES_SYMBOLS','DIGITS'] self.EOS='EOS' self.PADDING='PADDING' self.UNKNOWN='UNKNOWN' self.voc=get_vocabulary(voc_type,EOS=self.EOS,PADDING=self.PADDING,UNKNOWN=self.UNKNOWN) self.char2id=dict(zip(self.voc,range(len(self.voc)))) self.id2char=dict(zip(range(len(self.voc)),self.voc)) self.rec_num_classes=len(self.voc) self.lowercase=(voc_type=='LOWERCASE') def__len__(self): returnself.nSamples def__getitem__(self,index): assertindex<=len(self),'indexrangeerror' index+=1 img_key=b'image-%09d'%index imgbuf=self.txn.get(img_key) #由于Image.open需要一个类文件对象所以这里需要把二进制转为一个类文件对象 buf=six.BytesIO() 
buf.write(imgbuf) buf.seek(0) try: img=Image.open(buf).convert('RGB') #img=Image.open(buf).convert('L') #img=img.convert('RGB') exceptIOError: print('Corruptedimagefor%d'%index) returnself[index+1] #reconitionlabels label_key=b'label-%09d'%index word=self.txn.get(label_key).decode() ifself.lowercase: word=word.lower() ##fillwiththepaddingtoken label=np.full((self.max_len,),self.char2id[self.PADDING],dtype=np.int) label_list=[] forcharinword: ifcharinself.char2id: label_list.append(self.char2id[char]) else: ##addtheunknowntoken print('{0}isoutofvocabulary.'.format(char)) label_list.append(self.char2id[self.UNKNOWN]) ##addastoptoken label_list=label_list+[self.char2id[self.EOS]] assertlen(label_list)<=self.max_len label[:len(label_list)]=np.array(label_list) iflen(label)<=0: returnself[index+1] #labellength label_len=len(label_list) ifself.transformisnotNone: img=self.transform(img) returnimg,label,label_len
更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。