Pytorch 数据加载与数据预处理方式
数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况。
torchvision.datasets中的数据集
torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含__len_(获取数据集长度)和_getItem__(获取数据集中每一项)两个子方法。
Dataset源码如上,可以看到其中包含了两个没有实现的子方法,之后所有的Dataet类都继承该类,并根据数据情况定制这两个子方法的具体实现。
因此当我们需要加载自己的数据集的时候也可以借鉴这种方法,只需要继承torch.utils.data.Dataset类并重写init,len,以及getitem这三个方法即可。这样组着的类可以直接作为参数传入到torch.util.data.DataLoader中去。
以CIFAR10为例源码:
classtorchvision.datasets.CIFAR10(root,train=True,transform=None,target_transform=None,download=False)
root(string)–Rootdirectoryofdatasetwheredirectorycifar-10-batches-pyexistsorwillbesavedtoifdownloadissettoTrue. train(bool,optional)–IfTrue,createsdatasetfromtrainingset,otherwisecreatesfromtestset. transform(callable,optional)–Afunction/transformthattakesinanPILimageandreturnsatransformedversion.E.g,transforms.RandomCrop target_transform(callable,optional)–Afunction/transformthattakesinthetargetandtransformsit. download(bool,optional)–Iftrue,downloadsthedatasetfromtheinternetandputsitinrootdirectory.Ifdatasetisalreadydownloaded,itisnotdownloadedagain.
加载自己的数据集
对于torchvision.datasets中有两个不同的类,分别为DatasetFolder和ImageFolder,ImageFolder是继承自DatasetFolder。
下面我们通过源码来看一看folder文件中DatasetFolder和ImageFolder分别做了些什么
importtorch.utils.dataasdata
fromPILimportImage
importos
importos.path
defhas_file_allowed_extension(filename,extensions)://检查输入是否是规定的扩展名
"""Checksifafileisanallowedextension.
Args:
filename(string):pathtoafile
Returns:
bool:Trueifthefilenameendswithaknownimageextension
"""
filename_lower=filename.lower()
returnany(filename_lower.endswith(ext)forextinextensions)
deffind_classes(dir):
classes=[dfordinos.listdir(dir)ifos.path.isdir(os.path.join(dir,d))]//获取root目录下所有的文件夹名称
classes.sort()
class_to_idx={classes[i]:iforiinrange(len(classes))}//生成类别名称与类别id的对应Dictionary
returnclasses,class_to_idx
defmake_dataset(dir,class_to_idx,extensions):
images=[]
dir=os.path.expanduser(dir)//将~和~user转化为用户目录,对参数中出现~进行处理
fortargetinsorted(os.listdir(dir)):
d=os.path.join(dir,target)
ifnotos.path.isdir(d):
continue
forroot,_,fnamesinsorted(os.walk(d))://os.work包含三个部分,root代表该目录路径_代表该路径下的文件夹名称集合,fnames代表该路径下的文件名称集合
forfnameinsorted(fnames):
ifhas_file_allowed_extension(fname,extensions):
path=os.path.join(root,fname)
item=(path,class_to_idx[target])
images.append(item)//生成(训练样本图像目录,训练样本所属类别)的元组
returnimages//返回上述元组的列表
classDatasetFolder(data.Dataset):
"""Agenericdataloaderwherethesamplesarearrangedinthisway:::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root(string):Rootdirectorypath.
loader(callable):Afunctiontoloadasamplegivenitspath.
extensions(list[string]):Alistofallowedextensions.
transform(callable,optional):Afunction/transformthattakesin
asampleandreturnsatransformedversion.
E.g,``transforms.RandomCrop``forimages.
target_transform(callable,optional):Afunction/transformthattakes
inthetargetandtransformsit.
Attributes:
classes(list):Listoftheclassnames.
class_to_idx(dict):Dictwithitems(class_name,class_index).
samples(list):Listof(samplepath,class_index)tuples
"""
def__init__(self,root,loader,extensions,transform=None,target_transform=None):
classes,class_to_idx=find_classes(root)
samples=make_dataset(root,class_to_idx,extensions)
iflen(samples)==0:
raise(RuntimeError("Found0filesinsubfoldersof:"+root+"\n"
"Supportedextensionsare:"+",".join(extensions)))
self.root=root
self.loader=loader
self.extensions=extensions
self.classes=classes
self.class_to_idx=class_to_idx
self.samples=samples
self.transform=transform
self.target_transform=target_transform
def__getitem__(self,index):
"""
根据index获取sample返回值为(sample,target)元组,同时如果该类输入参数中有transform和target_transform,torchvision.transforms类型的参数时,将获取的元组分别执行transform和target_transform中的数据转换方法。
Args:
index(int):Index
Returns:
tuple:(sample,target)wheretargetisclass_indexofthetargetclass.
"""
path,target=self.samples[index]
sample=self.loader(path)
ifself.transformisnotNone:
sample=self.transform(sample)
ifself.target_transformisnotNone:
target=self.target_transform(target)
returnsample,target
def__len__(self):
returnlen(self.samples)
def__repr__(self)://定义输出对象格式其中和__str__的区别是__repr__无论是print输出还是直接输出对象自身都是以定义的格式进行输出,而__str__只有在print输出的时候会是以定义的格式进行输出
fmt_str='Dataset'+self.__class__.__name__+'\n'
fmt_str+='Numberofdatapoints:{}\n'.format(self.__len__())
fmt_str+='RootLocation:{}\n'.format(self.root)
tmp='Transforms(ifany):'
fmt_str+='{0}{1}\n'.format(tmp,self.transform.__repr__().replace('\n','\n'+''*len(tmp)))
tmp='TargetTransforms(ifany):'
fmt_str+='{0}{1}'.format(tmp,self.target_transform.__repr__().replace('\n','\n'+''*len(tmp)))
returnfmt_str
IMG_EXTENSIONS=['.jpg','.jpeg','.png','.ppm','.bmp','.pgm','.tif']
defpil_loader(path):
#openpathasfiletoavoidResourceWarning(https://github.com/python-pillow/Pillow/issues/835)
withopen(path,'rb')asf:
img=Image.open(f)
returnimg.convert('RGB')
defaccimage_loader(path):
importaccimage
try:
returnaccimage.Image(path)
exceptIOError:
#Potentiallyadecodingproblem,fallbacktoPIL.Image
returnpil_loader(path)
defdefault_loader(path):
fromtorchvisionimportget_image_backend
ifget_image_backend()=='accimage':
returnaccimage_loader(path)
else:
returnpil_loader(path)
classImageFolder(DatasetFolder):
"""Agenericdataloaderwheretheimagesarearrangedinthisway:::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root(string):Rootdirectorypath.
transform(callable,optional):Afunction/transformthattakesinanPILimage
andreturnsatransformedversion.E.g,``transforms.RandomCrop``
target_transform(callable,optional):Afunction/transformthattakesinthe
targetandtransformsit.
loader(callable,optional):Afunctiontoloadanimagegivenitspath.
Attributes:
classes(list):Listoftheclassnames.
class_to_idx(dict):Dictwithitems(class_name,class_index).
imgs(list):Listof(imagepath,class_index)tuples
"""
def__init__(self,root,transform=None,target_transform=None,
loader=default_loader):
super(ImageFolder,self).__init__(root,loader,IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs=self.samples
如果自己所要加载的数据组织形式如下
root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png
即不同类别的训练数据分别存储在不同的文件夹中,这些文件夹都在root(即形如D:/animals或者/usr/animals)路径下
classtorchvision.datasets.ImageFolder(root,transform=None,target_transform=None,loader=)
参数如下:
root(string)–Rootdirectorypath. transform(callable,optional)–Afunction/transformthattakesinanPILimageandreturnsatransformedversion.E.g,transforms.RandomCrop target_transform(callable,optional)–Afunction/transformthattakesinthetargetandtransformsit. loader–Afunctiontoloadanimagegivenitspath.就是上述源码中 __getitem__(index) Parameters:index(int)–Index Returns:(sample,target)wheretargetisclass_indexofthetargetclass. Returntype:tuple
可以通过torchvision.datasets.ImageFolder进行加载
img_data=torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
)
print(len(img_data))
data_loader=torch.utils.data.DataLoader(img_data,batch_size=20,shuffle=True)
print(len(data_loader))
对于所有的训练样本都在一个文件夹中同时有一个对应的txt文件每一行分别是对应图像的路径以及其所属的类别,可以参照上述class写出对应的加载类
defdefault_loader(path):
returnImage.open(path).convert('RGB')
classMyDataset(Dataset):
def__init__(self,txt,transform=None,target_transform=None,loader=default_loader):
fh=open(txt,'r')
imgs=[]
forlineinfh:
line=line.strip('\n')
line=line.rstrip()
words=line.split()
imgs.append((words[0],int(words[1])))
self.imgs=imgs
self.transform=transform
self.target_transform=target_transform
self.loader=loader
def__getitem__(self,index):
fn,label=self.imgs[index]
img=self.loader(fn)
ifself.transformisnotNone:
img=self.transform(img)
returnimg,label
def__len__(self):
returnlen(self.imgs)
train_data=MyDataset(txt='mnist_test.txt',transform=transforms.ToTensor())
data_loader=DataLoader(train_data,batch_size=100,shuffle=True)
print(len(data_loader))
DataLoader解析
位于torch.util.data.DataLoader中源代码
该接口的主要目的是将pytorch中已有的数据接口如torchvision.datasets.ImageFolder,或者自定义的数据读取接口转化按照
batch_size的大小封装为Tensor,即相当于在内置数据接口或者自定义数据接口的基础上增加一维,大小为batch_size的大小,
得到的数据在之后可以通过封装为Variable,作为模型的输出
__init__中所需的参数如下
1.datasettorch.utils.data.Dataset类的子类,可以是torchvision.datasets.ImageFolder等内置类,也可是继承了torch.utils.data.Dataset的自定义类 2.batch_size每一个batch中包含的样本个数,默认是1 3.shuffle一般在训练集中采用,默认是false,设置为true则每一个epoch都会将训练样本打乱 4.sampler训练样本选取策略,和shuffle是互斥的如果shuffle为true,该参数一定要为None 5.batch_samplerBatchSampler一次产生一个batch的indices,和sampler以及shuffle互斥,一般使用默认的即可 上述Sampler的源代码地址如下[源代码](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py) 6.num_workers用于数据加载的线程数量默认为0即只有主线程用来加载数据 7.collate_fn用来聚合数据生成mini_batch
使用的时候一般为如下使用方法:
train_data=torch.utils.data.DataLoader(...) fori,(input,target)inenumerate(train_data): ...
循环取DataLoader中的数据会触发类中__iter__方法,查看源代码可知其中调用的方法为return_DataLoaderIter(self),因此需要查看DataLoaderIter这一内部类
classDataLoaderIter(object):
"IteratesonceovertheDataLoader'sdataset,asspecifiedbythesampler"
def__init__(self,loader):
self.dataset=loader.dataset
self.collate_fn=loader.collate_fn
self.batch_sampler=loader.batch_sampler
self.num_workers=loader.num_workers
self.pin_memory=loader.pin_memoryandtorch.cuda.is_available()
self.timeout=loader.timeout
self.done_event=threading.Event()
self.sample_iter=iter(self.batch_sampler)
ifself.num_workers>0:
self.worker_init_fn=loader.worker_init_fn
self.index_queue=multiprocessing.SimpleQueue()
self.worker_result_queue=multiprocessing.SimpleQueue()
self.batches_outstanding=0
self.worker_pids_set=False
self.shutdown=False
self.send_idx=0
self.rcvd_idx=0
self.reorder_dict={}
base_seed=torch.LongTensor(1).random_()[0]
self.workers=[
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset,self.index_queue,self.worker_result_queue,self.collate_fn,
base_seed+i,self.worker_init_fn,i))
foriinrange(self.num_workers)]
ifself.pin_memoryorself.timeout>0:
self.data_queue=queue.Queue()
self.worker_manager_thread=threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue,self.data_queue,self.done_event,self.pin_memory,
torch.cuda.current_device()))
self.worker_manager_thread.daemon=True
self.worker_manager_thread.start()
else:
self.data_queue=self.worker_result_queue
forwinself.workers:
w.daemon=True#ensurethattheworkerexitsonprocessexit
w.start()
_update_worker_pids(id(self),tuple(w.pidforwinself.workers))
_set_SIGCHLD_handler()
self.worker_pids_set=True
#primetheprefetchloop
for_inrange(2*self.num_workers):
self._put_indices()
以上这篇Pytorch数据加载与数据预处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。