Pytorch 数据加载与数据预处理方式
数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况。
torchvision.datasets中的数据集
torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含__len_(获取数据集长度)和_getItem__(获取数据集中每一项)两个子方法。
Dataset源码如上,可以看到其中包含了两个没有实现的子方法,之后所有的Dataet类都继承该类,并根据数据情况定制这两个子方法的具体实现。
因此当我们需要加载自己的数据集的时候也可以借鉴这种方法,只需要继承torch.utils.data.Dataset类并重写init,len,以及getitem这三个方法即可。这样组着的类可以直接作为参数传入到torch.util.data.DataLoader中去。
以CIFAR10为例源码:
classtorchvision.datasets.CIFAR10(root,train=True,transform=None,target_transform=None,download=False)
root(string)–Rootdirectoryofdatasetwheredirectorycifar-10-batches-pyexistsorwillbesavedtoifdownloadissettoTrue. train(bool,optional)–IfTrue,createsdatasetfromtrainingset,otherwisecreatesfromtestset. transform(callable,optional)–Afunction/transformthattakesinanPILimageandreturnsatransformedversion.E.g,transforms.RandomCrop target_transform(callable,optional)–Afunction/transformthattakesinthetargetandtransformsit. download(bool,optional)–Iftrue,downloadsthedatasetfromtheinternetandputsitinrootdirectory.Ifdatasetisalreadydownloaded,itisnotdownloadedagain.
加载自己的数据集
对于torchvision.datasets中有两个不同的类,分别为DatasetFolder和ImageFolder,ImageFolder是继承自DatasetFolder。
下面我们通过源码来看一看folder文件中DatasetFolder和ImageFolder分别做了些什么
importtorch.utils.dataasdata fromPILimportImage importos importos.path defhas_file_allowed_extension(filename,extensions)://检查输入是否是规定的扩展名 """Checksifafileisanallowedextension. Args: filename(string):pathtoafile Returns: bool:Trueifthefilenameendswithaknownimageextension """ filename_lower=filename.lower() returnany(filename_lower.endswith(ext)forextinextensions) deffind_classes(dir): classes=[dfordinos.listdir(dir)ifos.path.isdir(os.path.join(dir,d))]//获取root目录下所有的文件夹名称 classes.sort() class_to_idx={classes[i]:iforiinrange(len(classes))}//生成类别名称与类别id的对应Dictionary returnclasses,class_to_idx defmake_dataset(dir,class_to_idx,extensions): images=[] dir=os.path.expanduser(dir)//将~和~user转化为用户目录,对参数中出现~进行处理 fortargetinsorted(os.listdir(dir)): d=os.path.join(dir,target) ifnotos.path.isdir(d): continue forroot,_,fnamesinsorted(os.walk(d))://os.work包含三个部分,root代表该目录路径_代表该路径下的文件夹名称集合,fnames代表该路径下的文件名称集合 forfnameinsorted(fnames): ifhas_file_allowed_extension(fname,extensions): path=os.path.join(root,fname) item=(path,class_to_idx[target]) images.append(item)//生成(训练样本图像目录,训练样本所属类别)的元组 returnimages//返回上述元组的列表 classDatasetFolder(data.Dataset): """Agenericdataloaderwherethesamplesarearrangedinthisway::: root/class_x/xxx.ext root/class_x/xxy.ext root/class_x/xxz.ext root/class_y/123.ext root/class_y/nsdf3.ext root/class_y/asd932_.ext Args: root(string):Rootdirectorypath. loader(callable):Afunctiontoloadasamplegivenitspath. extensions(list[string]):Alistofallowedextensions. transform(callable,optional):Afunction/transformthattakesin asampleandreturnsatransformedversion. E.g,``transforms.RandomCrop``forimages. target_transform(callable,optional):Afunction/transformthattakes inthetargetandtransformsit. Attributes: classes(list):Listoftheclassnames. class_to_idx(dict):Dictwithitems(class_name,class_index). samples(list):Listof(samplepath,class_index)tuples """ def__init__(self,root,loader,extensions,transform=None,target_transform=None): classes,class_to_idx=find_classes(root) samples=make_dataset(root,class_to_idx,extensions) iflen(samples)==0: raise(RuntimeError("Found0filesinsubfoldersof:"+root+"\n" "Supportedextensionsare:"+",".join(extensions))) self.root=root self.loader=loader self.extensions=extensions self.classes=classes self.class_to_idx=class_to_idx self.samples=samples self.transform=transform self.target_transform=target_transform def__getitem__(self,index): """ 根据index获取sample返回值为(sample,target)元组,同时如果该类输入参数中有transform和target_transform,torchvision.transforms类型的参数时,将获取的元组分别执行transform和target_transform中的数据转换方法。 Args: index(int):Index Returns: tuple:(sample,target)wheretargetisclass_indexofthetargetclass. """ path,target=self.samples[index] sample=self.loader(path) ifself.transformisnotNone: sample=self.transform(sample) ifself.target_transformisnotNone: target=self.target_transform(target) returnsample,target def__len__(self): returnlen(self.samples) def__repr__(self)://定义输出对象格式其中和__str__的区别是__repr__无论是print输出还是直接输出对象自身都是以定义的格式进行输出,而__str__只有在print输出的时候会是以定义的格式进行输出 fmt_str='Dataset'+self.__class__.__name__+'\n' fmt_str+='Numberofdatapoints:{}\n'.format(self.__len__()) fmt_str+='RootLocation:{}\n'.format(self.root) tmp='Transforms(ifany):' fmt_str+='{0}{1}\n'.format(tmp,self.transform.__repr__().replace('\n','\n'+''*len(tmp))) tmp='TargetTransforms(ifany):' fmt_str+='{0}{1}'.format(tmp,self.target_transform.__repr__().replace('\n','\n'+''*len(tmp))) returnfmt_str IMG_EXTENSIONS=['.jpg','.jpeg','.png','.ppm','.bmp','.pgm','.tif'] defpil_loader(path): #openpathasfiletoavoidResourceWarning(https://github.com/python-pillow/Pillow/issues/835) withopen(path,'rb')asf: img=Image.open(f) returnimg.convert('RGB') defaccimage_loader(path): importaccimage try: returnaccimage.Image(path) exceptIOError: #Potentiallyadecodingproblem,fallbacktoPIL.Image returnpil_loader(path) defdefault_loader(path): fromtorchvisionimportget_image_backend ifget_image_backend()=='accimage': returnaccimage_loader(path) else: returnpil_loader(path) classImageFolder(DatasetFolder): """Agenericdataloaderwheretheimagesarearrangedinthisway::: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png Args: root(string):Rootdirectorypath. transform(callable,optional):Afunction/transformthattakesinanPILimage andreturnsatransformedversion.E.g,``transforms.RandomCrop`` target_transform(callable,optional):Afunction/transformthattakesinthe targetandtransformsit. loader(callable,optional):Afunctiontoloadanimagegivenitspath. Attributes: classes(list):Listoftheclassnames. class_to_idx(dict):Dictwithitems(class_name,class_index). imgs(list):Listof(imagepath,class_index)tuples """ def__init__(self,root,transform=None,target_transform=None, loader=default_loader): super(ImageFolder,self).__init__(root,loader,IMG_EXTENSIONS, transform=transform, target_transform=target_transform) self.imgs=self.samples
如果自己所要加载的数据组织形式如下
root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png
即不同类别的训练数据分别存储在不同的文件夹中,这些文件夹都在root(即形如D:/animals或者/usr/animals)路径下
classtorchvision.datasets.ImageFolder(root,transform=None,target_transform=None,loader=)
参数如下:
root(string)–Rootdirectorypath. transform(callable,optional)–Afunction/transformthattakesinanPILimageandreturnsatransformedversion.E.g,transforms.RandomCrop target_transform(callable,optional)–Afunction/transformthattakesinthetargetandtransformsit. loader–Afunctiontoloadanimagegivenitspath.就是上述源码中 __getitem__(index) Parameters:index(int)–Index Returns:(sample,target)wheretargetisclass_indexofthetargetclass. Returntype:tuple
可以通过torchvision.datasets.ImageFolder进行加载
img_data=torchvision.datasets.ImageFolder('D:/bnu/database/flower', transform=transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor()]) ) print(len(img_data)) data_loader=torch.utils.data.DataLoader(img_data,batch_size=20,shuffle=True) print(len(data_loader))
对于所有的训练样本都在一个文件夹中同时有一个对应的txt文件每一行分别是对应图像的路径以及其所属的类别,可以参照上述class写出对应的加载类
defdefault_loader(path): returnImage.open(path).convert('RGB') classMyDataset(Dataset): def__init__(self,txt,transform=None,target_transform=None,loader=default_loader): fh=open(txt,'r') imgs=[] forlineinfh: line=line.strip('\n') line=line.rstrip() words=line.split() imgs.append((words[0],int(words[1]))) self.imgs=imgs self.transform=transform self.target_transform=target_transform self.loader=loader def__getitem__(self,index): fn,label=self.imgs[index] img=self.loader(fn) ifself.transformisnotNone: img=self.transform(img) returnimg,label def__len__(self): returnlen(self.imgs) train_data=MyDataset(txt='mnist_test.txt',transform=transforms.ToTensor()) data_loader=DataLoader(train_data,batch_size=100,shuffle=True) print(len(data_loader))
DataLoader解析
位于torch.util.data.DataLoader中源代码
该接口的主要目的是将pytorch中已有的数据接口如torchvision.datasets.ImageFolder,或者自定义的数据读取接口转化按照
batch_size的大小封装为Tensor,即相当于在内置数据接口或者自定义数据接口的基础上增加一维,大小为batch_size的大小,
得到的数据在之后可以通过封装为Variable,作为模型的输出
__init__中所需的参数如下
1.datasettorch.utils.data.Dataset类的子类,可以是torchvision.datasets.ImageFolder等内置类,也可是继承了torch.utils.data.Dataset的自定义类 2.batch_size每一个batch中包含的样本个数,默认是1 3.shuffle一般在训练集中采用,默认是false,设置为true则每一个epoch都会将训练样本打乱 4.sampler训练样本选取策略,和shuffle是互斥的如果shuffle为true,该参数一定要为None 5.batch_samplerBatchSampler一次产生一个batch的indices,和sampler以及shuffle互斥,一般使用默认的即可 上述Sampler的源代码地址如下[源代码](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py) 6.num_workers用于数据加载的线程数量默认为0即只有主线程用来加载数据 7.collate_fn用来聚合数据生成mini_batch
使用的时候一般为如下使用方法:
train_data=torch.utils.data.DataLoader(...) fori,(input,target)inenumerate(train_data): ...
循环取DataLoader中的数据会触发类中__iter__方法,查看源代码可知其中调用的方法为return_DataLoaderIter(self),因此需要查看DataLoaderIter这一内部类
classDataLoaderIter(object): "IteratesonceovertheDataLoader'sdataset,asspecifiedbythesampler" def__init__(self,loader): self.dataset=loader.dataset self.collate_fn=loader.collate_fn self.batch_sampler=loader.batch_sampler self.num_workers=loader.num_workers self.pin_memory=loader.pin_memoryandtorch.cuda.is_available() self.timeout=loader.timeout self.done_event=threading.Event() self.sample_iter=iter(self.batch_sampler) ifself.num_workers>0: self.worker_init_fn=loader.worker_init_fn self.index_queue=multiprocessing.SimpleQueue() self.worker_result_queue=multiprocessing.SimpleQueue() self.batches_outstanding=0 self.worker_pids_set=False self.shutdown=False self.send_idx=0 self.rcvd_idx=0 self.reorder_dict={} base_seed=torch.LongTensor(1).random_()[0] self.workers=[ multiprocessing.Process( target=_worker_loop, args=(self.dataset,self.index_queue,self.worker_result_queue,self.collate_fn, base_seed+i,self.worker_init_fn,i)) foriinrange(self.num_workers)] ifself.pin_memoryorself.timeout>0: self.data_queue=queue.Queue() self.worker_manager_thread=threading.Thread( target=_worker_manager_loop, args=(self.worker_result_queue,self.data_queue,self.done_event,self.pin_memory, torch.cuda.current_device())) self.worker_manager_thread.daemon=True self.worker_manager_thread.start() else: self.data_queue=self.worker_result_queue forwinself.workers: w.daemon=True#ensurethattheworkerexitsonprocessexit w.start() _update_worker_pids(id(self),tuple(w.pidforwinself.workers)) _set_SIGCHLD_handler() self.worker_pids_set=True #primetheprefetchloop for_inrange(2*self.num_workers): self._put_indices()
以上这篇Pytorch数据加载与数据预处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。