An example of iterating over a dataset in batches and printing data in PyTorch
I'll skip the chatter and get straight to the code!
from os import listdir
import os
from time import time
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image  # needed by the load_img helper defined below


def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100,
                     fill='=', empty=' ', tip='>', begin='[', end=']', done="[DONE]", clear=True):
    """Print a progress bar that redraws itself in place on a single line."""
    percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
    filledLength = int(length * iteration // total)
    bar = fill * filledLength
    if iteration != total:
        bar = bar + tip
    bar = bar + empty * (length - filledLength - len(tip))
    display = '\r{prefix}{begin}{bar}{end} {percent}%{suffix}' \
        .format(prefix=prefix, begin=begin, bar=bar, end=end, percent=percent, suffix=suffix)
    print(display, end=''),  # comma after print() required for Python 2
    if iteration == total:  # print with a newline on completion
        if clear:  # display the completion message, padded with spaces to 'erase' the previous bar
            finish = '\r{prefix}{done}'.format(prefix=prefix, done=done)
            if hasattr(str, 'decode'):  # handle Python 2 non-unicode strings for a proper length measure
                finish = finish.decode('utf-8')
                display = display.decode('utf-8')
            clear = ' ' * max(len(display) - len(finish), 0)
            print(finish + clear)
        else:
            print('')
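
The progress bar can be tried on its own before touching any data. Here is a minimal standalone sketch; the step count and the sleep call are purely illustrative and not part of the original script:

from time import sleep  # hypothetical demo only: sleep simulates per-step work

steps = 50
for i in range(steps):
    sleep(0.05)  # pretend to do some work
    printProgressBar(i + 1, steps, prefix='Demo ', length=20)
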
# is_image_file and load_img are used by the dataset below but were missing from
# the original snippet; these are typical definitions (the 256x256 resize is an assumption).
def is_image_file(filename):
    return any(filename.endswith(ext) for ext in [".png", ".jpg", ".jpeg"])


def load_img(filepath):
    img = Image.open(filepath).convert('RGB')
    img = img.resize((256, 256), Image.BICUBIC)
    return img


class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir):
        super(DatasetFromFolder, self).__init__()
        self.photo_path = os.path.join(image_dir, "a")
        self.sketch_path = os.path.join(image_dir, "b")
        self.image_filenames = [x for x in listdir(self.photo_path) if is_image_file(x)]
        transform_list = [transforms.ToTensor(),
                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        self.transform = transforms.Compose(transform_list)

    def __getitem__(self, index):
        # Load the paired images and map their pixel values to [-1, 1]
        input = load_img(os.path.join(self.photo_path, self.image_filenames[index]))
        input = self.transform(input)
        target = load_img(os.path.join(self.sketch_path, self.image_filenames[index]))
        target = self.transform(target)
        return input, target

    def __len__(self):
        return len(self.image_filenames)
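
Before wiring the dataset into a DataLoader, it can be sanity-checked by indexing it directly. A minimal sketch, assuming the ./dataset/facades/train layout used in the main block below (the printed shapes follow from the 256x256 resize in load_img):

dataset = DatasetFromFolder("./dataset/facades/train")  # hypothetical quick check
a, b = dataset[0]        # __getitem__ returns the (input, target) pair
print(len(dataset))      # number of image files found under .../a
print(a.shape, b.shape)  # torch.Size([3, 256, 256]) for each
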
if __name__ == '__main__':
    dataset = DatasetFromFolder("./dataset/facades/train")
    dataloader = DataLoader(dataset=dataset, num_workers=8, batch_size=1, shuffle=True)
    total = len(dataloader)
    for epoch in range(20):
        t0 = time()
        for i, batch in enumerate(dataloader):
            real_a, real_b = batch[0], batch[1]
            # Passing total + 1 keeps the in-loop calls from hitting the completion
            # branch; the explicit call after the loop prints the final line instead.
            printProgressBar(i + 1, total + 1,
                             length=20,
                             prefix='Epoch %d ' % epoch,  # original hard-coded str(1) here
                             suffix=', d_loss: %d' % 1)   # dummy loss value for the demo
        printProgressBar(total, total,
                         done='Epoch [%d]' % epoch +
                              ', time: %.2fs' % (time() - t0))
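
One detail worth noting: with a batch_size greater than 1, the DataLoader stacks the per-sample tensors along a new leading batch dimension. A quick sketch, assuming the same dataset as above (num_workers=0 keeps the demo simple):

loader = DataLoader(dataset=dataset, num_workers=0, batch_size=4, shuffle=True)  # hypothetical
real_a, real_b = next(iter(loader))
print(real_a.shape)  # torch.Size([4, 3, 256, 256]) given the 256x256 RGB images
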
That's all for this example of iterating over a dataset in batches and printing data in PyTorch. I hope it serves as a useful reference, and thanks for supporting 毛票票.