pytorch ImageFolder的覆写实例
在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:
CLASStorchvision.datasets.ImageFolder(root,transform=None,target_transform=None,loader=
,is_valid_file=None)
使用可见pytorchtorchvision.ImageFolder的用法介绍
这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能
首先先分析下其源代码:
IMG_EXTENSIONS=['.jpg','.jpeg','.png','.ppm','.bmp','.pgm','.tif','.tiff','webp'] classImageFolder(DatasetFolder): """Agenericdataloaderwheretheimagesarearrangedinthisway::: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png Args: root(string):Rootdirectorypath. transform(callable,optional):Afunction/transformthattakesinanPILimage andreturnsatransformedversion.E.g,``transforms.RandomCrop`` target_transform(callable,optional):Afunction/transformthattakesinthe targetandtransformsit. loader(callable,optional):Afunctiontoloadanimagegivenitspath. Attributes: classes(list):Listoftheclassnames. class_to_idx(dict):Dictwithitems(class_name,class_index). imgs(list):Listof(imagepath,class_index)tuples """ def__init__(self,root,transform=None,target_transform=None, loader=default_loader): super(ImageFolder,self).__init__(root,loader,IMG_EXTENSIONS, transform=transform, target_transform=target_transform) self.imgs=self.samples
ImageFolder的代码很简单,主要是继承了DatasetFolder:
defhas_file_allowed_extension(filename,extensions):
"""查看文件是否是支持的可扩展类型
Args:
filename(string):文件路径
extensions(iterableofstrings):可扩展类型列表,即能接受的图像文件类型
Returns:
bool:Trueifthefilenameendswithoneofgivenextensions
"""
filename_lower=filename.lower()
returnany(filename_lower.endswith(ext)forextinextensions)#返回True或False列表
defmake_dataset(dir,class_to_idx,extensions):
"""
返回形如[(图像路径,该图像对应的类别索引值),(),...]
"""
images=[]
dir=os.path.expanduser(dir)
fortargetinsorted(class_to_idx.keys()):
d=os.path.join(dir,target)
ifnotos.path.isdir(d):
continue
forroot,_,fnamesinsorted(os.walk(d)):#层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
forfnameinsorted(fnames):
ifhas_file_allowed_extension(fname,extensions):查看文件是否是支持的可扩展类型,是则继续
path=os.path.join(root,fname)
item=(path,class_to_idx[target])
images.append(item)
returnimages
classDatasetFolder(data.Dataset):
"""Agenericdataloaderwherethesamplesarearrangedinthisway:::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root(string):根目录路径
loader(callable):根据给定的路径来加载样本的可调用函数
extensions(list[string]):可扩展类型列表,即能接受的图像文件类型.
transform(callable,optional):用于样本的transform函数,然后返回样本transform后的版本
E.g,``transforms.RandomCrop``forimages.
target_transform(callable,optional):用于样本标签的transform函数
Attributes:
classes(list):类别名列表
class_to_idx(dict):项目(class_name,class_index)字典,如{'cat':0,'dog':1}
samples(list):(samplepath,class_index)元组列表,即(样本路径,类别索引)
targets(list):在数据集中每张图片的类索引值,为列表
"""
def__init__(self,root,loader,extensions,transform=None,target_transform=None):
classes,class_to_idx=self._find_classes(root)#得到类名和类索引,如['cat','dog']和{'cat':0,'dog':1}
#返回形如[(图像路径,该图像对应的类别索引值),(),...],即对每个图像进行标记
samples=make_dataset(root,class_to_idx,extensions)
iflen(samples)==0:
raise(RuntimeError("Found0filesinsubfoldersof:"+root+"\n"
"Supportedextensionsare:"+",".join(extensions)))
self.root=root
self.loader=loader
self.extensions=extensions
self.classes=classes
self.class_to_idx=class_to_idx
self.samples=samples
self.targets=[s[1]forsinsamples]#所有图像的类索引值组成的列表
self.transform=transform
self.target_transform=target_transform
def_find_classes(self,dir):
"""
在数据集中查找类文件夹。
Args:
dir(string):根目录路径
Returns:
返回元组:(classes,class_to_idx)即(类名,类索引),其中classes即相应的目录名,如['cat','dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat':0,'dog':1}.
Ensures:
保证没有类名是另一个类目录的子目录
"""
ifsys.version_info>=(3,5):
#FasterandavailableinPython3.5andabove
classes=[d.namefordinos.scandir(dir)ifd.is_dir()]#获得根目录dir的所有第一层子目录名
else:
classes=[dfordinos.listdir(dir)ifos.path.isdir(os.path.join(dir,d))]#效果和上面的一样,只是版本不同方法不同
classes.sort()#然后对类名进行排序
class_to_idx={classes[i]:iforiinrange(len(classes))}#然后将类名和索引值一一对应的到相应字典,如{'cat':0,'dog':1}
returnclasses,class_to_idx#然后返回类名和类索引
def__getitem__(self,index):
"""
Args:
index(int):Index
Returns:
tuple:(sample,target)wheretargetisclass_indexofthetargetclass.
"""
path,target=self.samples[index]
sample=self.loader(path)#加载图片
ifself.transformisnotNone:
sample=self.transform(sample)
ifself.target_transformisnotNone:
target=self.target_transform(target)
returnsample,target
def__len__(self):
returnlen(self.samples)
def__repr__(self):
fmt_str='Dataset'+self.__class__.__name__+'\n'
fmt_str+='Numberofdatapoints:{}\n'.format(self.__len__())
fmt_str+='RootLocation:{}\n'.format(self.root)
tmp='Transforms(ifany):'
fmt_str+='{0}{1}\n'.format(tmp,self.transform.__repr__().replace('\n','\n'+''*len(tmp)))
tmp='TargetTransforms(ifany):'
fmt_str+='{0}{1}'.format(tmp,self.target_transform.__repr__().replace('\n','\n'+''*len(tmp)))
returnfmt_str
此时想要覆写ImageFolder,代码为:
classCustomImageFolder(ImageFolder): """ 为了得到两张图(其中一张是随机选取的)的图像和索引值信息 """ def__init__(self,root,transform=None): super(CustomImageFolder,self).__init__(root,transform) self.indices=range(len(self))#该文件夹中的长度 def__getitem__(self,index1): index2=random.choice(self.indices)#从[0,indices]中随机抽取一个数字,为了随机选取一张图 path1=self.imgs[index1][0]#此时的self.imgs等于self.samples,即内容为[(图像路径,该图像对应的类别索引值),(),...] label1=self.imgs[index1][1] path2=self.imgs[index2][0] label2=self.imgs[index2][1] img1=self.loader(path1) img2=self.loader(path2) ifself.transformisnotNone: img1=self.transform(img1) img2=self.transform(img2) returnimg1,img2,label1,label2
以上这篇pytorchImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。