使用PyTorch将文件夹下的图片分为训练集和验证集实例
PyTorch(torchvision)提供了 ImageFolder 类,用来加载文件结构如下的图片数据集:
root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png
使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。
import os

import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
class ImageFolderSplitter:
    """Scan a class-per-folder image tree and split it into train/validation lists.

    Images should be placed in folders like:

        root
        root/dogs
        root/dogs/image1.png
        root/dogs/image2.png
        root/cats
        root/cats/image1.png
        root/cats/image2.png

    Args:
        path: the root of the image folder.
        train_size: forwarded to ``train_test_split`` as its ``train_size``
            argument (default 0.8).

    Raises:
        RuntimeError: if the folder layout does not match the structure above.
    """

    def __init__(self, path, train_size=0.8):
        self.path = path
        self.train_size = train_size
        self.class2num = {}    # class (folder) name -> integer label
        self.num2class = {}    # integer label -> class (folder) name
        self.class_nums = {}   # integer label -> number of files with that label
        self.data_x_path = []  # full path of every file found
        self.data_y_label = []  # integer label for each entry of data_x_path
        self.x_train = []
        self.x_valid = []
        self.y_train = []
        self.y_valid = []

        self._scan_folder()

        (
            self.x_train,
            self.x_valid,
            self.y_train,
            self.y_valid,
        ) = train_test_split(
            self.data_x_path,
            self.data_y_label,
            shuffle=True,
            train_size=self.train_size,
        )

    def _scan_folder(self):
        """Walk ``self.path`` and fill the class maps, file paths and labels."""
        for root, dirs, files in os.walk(self.path):
            if not files and len(dirs) > 1 and not self.class2num:
                # Top level: one sub-folder per class. Sorting in place makes
                # the label numbering deterministic and also fixes the order
                # in which os.walk descends into the class folders.
                dirs.sort()
                for index, class_name in enumerate(dirs):
                    self.num2class[index] = class_name
                    self.class2num[class_name] = index
            elif files and not dirs:
                # Leaf folder: its own name is the class. Matching on the
                # basename (not a substring of the whole path) keeps e.g.
                # "cat_data/dog" from being labelled as "cat".
                category = os.path.basename(root)
                if category not in self.class2num:
                    raise RuntimeError("pleasecheckthefolderstructure!")
                label = self.class2num[category]
                self.class_nums[label] = 0
                for file_name in files:
                    self.data_x_path.append(os.path.join(root, file_name))
                    self.data_y_label.append(label)
                    self.class_nums[label] += 1
            else:
                raise RuntimeError("pleasecheckthefolderstructure!")

    def getTrainingDataset(self):
        """Return ``(x_train, y_train)``: training file paths and their labels."""
        return self.x_train, self.y_train

    def getValidationDataset(self):
        """Return ``(x_valid, y_valid)``: validation file paths and their labels."""
        return self.x_valid, self.y_valid
class DatasetFromFilename(Dataset):
    """Dataset that opens each image from its file path when it is indexed.

    Args:
        x: a list of image file full paths.
        y: a list of image categories, parallel to ``x``.
        transforms: callable applied to the RGB PIL image; when omitted,
            ``ToTensor()`` is used.
    """

    def __init__(self, x, y, transforms=None):
        super().__init__()
        self.x = x
        self.y = y
        # Identity check: a transform object with a custom __eq__ must not be
        # mistaken for "no transform given".
        self.transforms = ToTensor() if transforms is None else transforms

    def __len__(self):
        """Return the number of image paths."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_tensor)`` for item ``idx``.

        The label is wrapped as ``torch.tensor([[label]])``.
        """
        # convert() returns a new image, so the file opened by Image.open can
        # be closed as soon as the conversion is done instead of lingering
        # until garbage collection.
        with Image.open(self.x[idx]) as source:
            img = source.convert("RGB")
        return self.transforms(img), torch.tensor([[self.y[idx]]])
# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#     print(x.shape, y.shape)
更多的代码可以在我的 GitHub repo 下找到。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。