pytorch制作自己的LMDB数据操作示例
本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:
前言
记录下pytorch里如何使用lmdb的code,自用
制作部分的Code
code就是ASTER里数据制作部分的代码改了点，aster_train.txt里面就是图片的完整路径，每行一个；图片同目录下有同名的txt，里面记着jpg的标签
import os
import re

import cv2
import lmdb  # install lmdb by "pip install lmdb"
import numpy as np
import scipy.io as sio
import six
from PIL import Image
from tqdm import tqdm
def checkImageIsValid(imageBin):
    """Return True if the raw image bytes decode to a non-empty image.

    Args:
        imageBin: raw bytes of an encoded image file, or None.

    Returns:
        bool: False for None input, undecodable data, or a zero-area image.
    """
    if imageBin is None:
        return False
    # np.frombuffer replaces the deprecated np.fromstring (read-only view, no copy).
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode signals failure by returning None, not by raising —
    # without this check the next line crashes with AttributeError.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True
def writeCache(env, cache):
    """Flush a {str key -> bytes value} cache into LMDB in a single write transaction."""
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            txn.put(key.encode(), value)
def_is_difficult(word):
assertisinstance(word,str)
returnnotre.match('^[\w]+$',word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """Create an LMDB dataset for CRNN training.

    Args:
        outputPath: LMDB output path.
        imagePathList: list of image paths.
        labelList: list of corresponding ground-truth texts.
        lexiconList: (optional) list of lexicon lists.
        checkValid: if True, check the validity of every image.
    """
    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)  # max map size: 1 TB
    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if len(label) == 0:
            continue
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        # LMDB stores binary data only: every value must be bytes.
        imageKey = 'image-%09d' % cnt  # counter zero-padded to 9 digits
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            # Encode to bytes: txn.put (via writeCache) rejects str values.
            cache[lexiconKey] = ''.join(lexiconList[i]).encode()
        if cnt % 1000 == 0:
            # Flush periodically so the in-memory cache stays bounded.
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path: str):
    """Collect image paths and their labels from an index file.

    Each line of *txt_path* holds the full path of a .jpg; a sibling .txt with
    the same stem holds its label. Images whose label file does not exist are
    skipped.

    Args:
        txt_path: path to the index file listing one image path per line.

    Returns:
        (jpg_list, txt_content_list): parallel lists of image paths and labels.

    Raises:
        UnicodeDecodeError: if a label file is not decodable; the offending
            path is printed before re-raising.
    """
    with open(txt_path, 'r') as fr:
        jpg_list = [x.strip() for x in fr.readlines()
                    if os.path.exists(x.replace('.jpg', '.txt').strip())]
    txt_content_list = []
    for jpg in jpg_list:
        label_path = jpg.replace('.jpg', '.txt')
        # Distinct handle name: the original re-used `fr`, shadowing the outer file.
        with open(label_path, 'r') as label_file:
            try:
                str_tmp = label_file.readline()
            except UnicodeDecodeError:
                print(label_path)
                # Bare raise preserves the original traceback (raise(e) re-wraps it).
                raise
        txt_content_list.append(str_tmp.strip())
    return jpg_list, txt_content_list
if __name__ == "__main__":
    # Site-specific paths: adjust before running.
    txt_path = '/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
    imagePathList, labelList = get_sample_list(txt_path)
    createDataset(lmdb_output_path, imagePathList, labelList)
读取部分
这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__
from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy

# Let PIL load images whose file data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
    # moxing is a distributed (cloud) framework; only needed on remote runs.
    import moxing as mox
class LmdbDataset(data.Dataset):
    """Dataset reading (image, label) samples from an LMDB built by createDataset."""

    def __init__(self, root, voc_type, max_len, num_samples, transform=None):
        """
        Args:
            root: path of the LMDB directory.
            voc_type: one of 'LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS', 'DIGITS'.
            max_len: fixed length of the padded label vector.
            num_samples: cap on the number of samples used.
            transform: optional callable applied to the PIL image.
        """
        super(LmdbDataset, self).__init__()

        if global_args.run_on_remote:
            # Copy the LMDB from remote storage into a local cache first.
            dataset_name = os.path.basename(root)
            data_cache_url = "/cache/%s" % dataset_name
            if not os.path.exists(data_cache_url):
                os.makedirs(data_cache_url)
            if mox.file.exists(root):
                mox.file.copy_parallel(root, data_cache_url)
            else:
                raise ValueError("%s not exists!" % root)
            self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
        else:
            self.env = lmdb.open(root, max_readers=32, readonly=True)

        assert self.env is not None, "cannot create lmdb from %s" % root
        # Single long-lived read transaction reused by every __getitem__.
        self.txn = self.env.begin()

        self.voc_type = voc_type
        self.transform = transform
        self.max_len = max_len
        self.nSamples = int(self.txn.get(b"num-samples"))
        self.nSamples = min(self.nSamples, num_samples)

        assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS', 'DIGITS']
        self.EOS = 'EOS'
        self.PADDING = 'PADDING'
        self.UNKNOWN = 'UNKNOWN'
        self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
        self.char2id = dict(zip(self.voc, range(len(self.voc))))
        self.id2char = dict(zip(range(len(self.voc)), self.voc))
        self.rec_num_classes = len(self.voc)
        self.lowercase = (voc_type == 'LOWERCASE')

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        # Valid 0-based indices are 0..nSamples-1 (the original `<=` allowed an
        # off-by-one access to a key that does not exist in the LMDB).
        assert index < len(self), 'index range error'
        index += 1  # LMDB keys are 1-based
        img_key = b'image-%09d' % index
        imgbuf = self.txn.get(img_key)

        # Image.open needs a file-like object, so wrap the raw bytes in BytesIO.
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('RGB')
        except IOError:
            print('Corrupted image for %d' % index)
            # `index` now equals the 1-based key, i.e. the 0-based position of
            # the *next* sample; wrap with modulo so we never run off the end.
            # (The original self[index + 1] double-incremented and skipped one.)
            return self[index % self.nSamples]

        # recognition labels
        label_key = b'label-%09d' % index
        word = self.txn.get(label_key).decode()
        if self.lowercase:
            word = word.lower()

        # Skip empty labels by retrying the next sample. The original tested
        # len(label) <= 0 after padding, which was always max_len and never fired.
        if len(word) == 0:
            return self[index % self.nSamples]

        # Fill with the padding token. dtype=int: the np.int alias was removed
        # from numpy 1.24+.
        label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=int)
        label_list = []
        for char in word:
            if char in self.char2id:
                label_list.append(self.char2id[char])
            else:
                # substitute the unknown token for out-of-vocabulary characters
                print('{0} is out of vocabulary.'.format(char))
                label_list.append(self.char2id[self.UNKNOWN])
        # append a stop token
        label_list = label_list + [self.char2id[self.EOS]]
        assert len(label_list) <= self.max_len
        label[:len(label_list)] = np.array(label_list)

        # label length (includes the EOS token)
        label_len = len(label_list)

        if self.transform is not None:
            img = self.transform(img)
        return img, label, label_len
更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。