Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
【源码GitHub地址】:点击进入
1.问题描述
之前写了一篇关于《pytorchDataset,DataLoader产生自定义的训练数据》的博客,但存在一个问题,我们不能在Dataset做一些数据清理,如果我们传递给Dataset数据,本身存在问题,那么迭代过程肯定出错的。
比如我把很多图片路径都传递给Dataset,如果图片路径都是正确的,且图片都存在也没有损坏,那显然运行是没有问题的;
但倘若传递给Dataset的图片路径有些图片是不存在,这时你通过Dataset读取图片数据,然后再迭代返回,就会出现类似如下的错误:
File"D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py",line68,in
return[default_collate(samples)forsamplesintransposed] File"D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py",line70,indefault_collate
raiseTypeError((error_msg_fmt.format(type(batch[0]))))TypeError:batchmustcontaintensors,numbers,dictsorlists;found
2.一般的解决方法
一般的解决方法也很简单粗暴,就是在传递数据给Dataset前,就做数据清理,把不存在的图片,损坏的数据都提前清理掉。
是的,这个是最简单粗暴的。
3.另一种解决方法:自定义返回数据的规则:collate_fn()校对函数
我们希望不管传递什么处理给Dataset,Dataset都进行处理,如果不存在或者异常,就返回None,而在DataLoader时,对于不存为None的数据,都去除掉。
这样就保证在迭代过程中,DataLoader获得batch数据都是正确的。
比如读取batch_size=5的图片数据,如果其中有1个(或者多个)图片是不存在,那么返回的batch应该把不存在的数据过滤掉,即返回5-1=4大小的batch的数据。
是的,我要实现的就是这个功能:返回的batch数据会自定清理掉不合法的数据。
3.1Pytorch数据处理函数:Dataset和DataLoader
Pytorch有两个数据处理函数:Dataset和DataLoader
fromtorch.utils.dataimportDataset,DataLoader
其中Dataset用于定义数据的读取和预处理操作,而DataLoader用于加载并产生批训练数据。
torch.utils.data.DataLoader参数说明:
DataLoader(object)可用参数:
1、dataset(Dataset)传入的数据集
2、batch_size(int,optional)每个batch有多少个样本
3、shuffle(bool,optional)在每个epoch开始的时候,对数据进行重新排序
4、sampler(Sampler,optional)自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
5、batch_sampler(Sampler,optional)与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutuallyexclusive)
6、num_workers(int,optional)这个参数决定了有几个进程来处理dataloading。0意味着所有的数据都会被load进主进程。(默认为0)
7、collate_fn(callable,optional)将一个list的sample组成一个mini-batch的函数
8、pin_memory(bool,optional)如果设置为True,那么dataloader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDApinnedmemory)中.
9、drop_last(bool,optional)如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了。如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
10、timeout(numeric,optional)如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
11、worker_init_fn(callable,optional)每个worker初始化函数IfnotNone,thiswillbecalledoneachworkersubprocesswiththeworkerid(anintin[0,num_workers-1])asinput,afterseedingandbeforedataloading.(default:None)
我们要用到的是collate_fn()回调函数
3.2自定义collate_fn()函数:
torch.utils.data.DataLoader的collate_fn()用于设置batch数据拼接方式,默认是default_collate函数,但当batch中含有None等数据时,默认的default_collate校队方法会出现错误。因此,我们需要自定义collate_fn()函数:
方法也很简单:只需在原来的default_collate函数中添加下面几句代码:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了。
#这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了 ifisinstance(batch,list): batch=[(image,image_id)for(image,image_id)inbatchifimageisnotNone] ifbatch==[]: return(None,None)
dataset_collate.py:
#-*-coding:utf-8-*- """ @Project:pytorch-learning-tutorials @File:dataset_collate.py @Author:panjq @E-mail:pan_jinquan@163.com @Date:2019-06-0717:09:13 """ r""""Containsdefinitionsofthemethodsusedbythe_DataLoaderIterworkersto collatesamplesfetchedfromdatasetintoTensor(s). These**needs**tobeinglobalscopesincePy2doesn'tsupportserializing staticmethods. """ importtorch importre fromtorch._siximportcontainer_abcs,string_classes,int_classes _use_shared_memory=False r"""Whethertousesharedmemoryindefault_collate""" np_str_obj_array_pattern=re.compile(r'[SaUO]') error_msg_fmt="batchmustcontaintensors,numbers,dictsorlists;found{}" numpy_type_map={ 'float64':torch.DoubleTensor, 'float32':torch.FloatTensor, 'float16':torch.HalfTensor, 'int64':torch.LongTensor, 'int32':torch.IntTensor, 'int16':torch.ShortTensor, 'int8':torch.CharTensor, 'uint8':torch.ByteTensor, } defcollate_fn(batch): ''' collate_fn(callable,optional):mergesalistofsamplestoformamini-batch. 该函数参考touch的default_collate函数,也是DataLoader的默认的校对方法,当batch中含有None等数据时, 默认的default_collate校队方法会出现错误 一种的解决方法是: 判断batch中image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了 :parambatch: :return: ''' r"""Putseachdatafieldintoatensorwithouterdimensionbatchsize""" #这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了 ifisinstance(batch,list): batch=[(image,image_id)for(image,image_id)inbatchifimageisnotNone] ifbatch==[]: return(None,None) elem_type=type(batch[0]) ifisinstance(batch[0],torch.Tensor): out=None if_use_shared_memory: #Ifwe'reinabackgroundprocess,concatenatedirectlyintoa #sharedmemorytensortoavoidanextracopy numel=sum([x.numel()forxinbatch]) storage=batch[0].storage()._new_shared(numel) out=batch[0].new(storage) returntorch.stack(batch,0,out=out) elifelem_type.__module__=='numpy'andelem_type.__name__!='str_'\ andelem_type.__name__!='string_': elem=batch[0] ifelem_type.__name__=='ndarray': #arrayofstringclassesandobject ifnp_str_obj_array_pattern.search(elem.dtype.str)isnotNone: raiseTypeError(error_msg_fmt.format(elem.dtype)) returncollate_fn([torch.from_numpy(b)forbinbatch]) ifelem.shape==():#scalars py_type=floatifelem.dtype.name.startswith('float')elseint returnnumpy_type_map[elem.dtype.name](list(map(py_type,batch))) elifisinstance(batch[0],float): returntorch.tensor(batch,dtype=torch.float64) elifisinstance(batch[0],int_classes): returntorch.tensor(batch) elifisinstance(batch[0],string_classes): returnbatch elifisinstance(batch[0],container_abcs.Mapping): return{key:collate_fn([d[key]fordinbatch])forkeyinbatch[0]} elifisinstance(batch[0],tuple)andhasattr(batch[0],'_fields'):#namedtuple returntype(batch[0])(*(collate_fn(samples)forsamplesinzip(*batch))) elifisinstance(batch[0],container_abcs.Sequence): transposed=zip(*batch)#ok return[collate_fn(samples)forsamplesintransposed] raiseTypeError((error_msg_fmt.format(type(batch[0]))))
测试方法:
#-*-coding:utf-8-*- """ @Project:pytorch-learning-tutorials @File:dataset.py @Author:panjq @E-mail:pan_jinquan@163.com @Date:2019-03-0718:45:06 """ importtorch fromtorch.autogradimportVariable fromtorchvisionimporttransforms fromtorch.utils.dataimportDataset,DataLoader importnumpyasnp fromutilsimportdataset_collate importos importcv2 fromPILimportImage defread_image(path,mode='RGB'): ''' :parampath: :parammode:RGBorL :return: ''' returnImage.open(path).convert(mode) classTorchDataset(Dataset): def__init__(self,image_id_list,image_dir,resize_height=256,resize_width=256,repeat=1,transform=None): ''' :paramfilename:数据文件TXT:格式:imge_name.jpglabel1_idlabe2_id :paramimage_dir:图片路径:image_dir+imge_name.jpg构成图片的完整路径 :paramresize_height为None时,不进行缩放 :paramresize_width为None时,不进行缩放, PS:当参数resize_height或resize_width其中一个为None时,可实现等比例缩放 :paramrepeat:所有样本数据重复次数,默认循环一次,当repeat为None时,表示无限循环换成shape=(C,H,W),并且归一化到[0.0,1.0]的torch.FloatTensor类型 #transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))#给定均值(R,G,B)方差(R,G,B),将会把Tensor正则化 ]) epoch_num=2#总样本循环次数 batch_size=5#训练时的一组数据的大小 train_data_nums=10 max_iterate=int((train_data_nums+batch_size-1)/batch_size*epoch_num)#总迭代次数 train_data=TorchDataset(image_id_list=image_id_list, image_dir=image_dir, resize_height=resize_height, resize_width=resize_width, repeat=1, transform=train_transform) #使用默认的default_collate会报错 #train_loader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=False) #使用自定义的collate_fn train_loader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=False,collate_fn=dataset_collate.collate_fn) #[1]使用epoch方法迭代,TorchDataset的参数repeat=1 forepochinrange(epoch_num): forstep,(batch_image,batch_label)inenumerate(train_loader): ifbatch_imageisNoneandbatch_labelisNone: print("batch_image:{},batch_label:{}".format(batch_image,batch_label)) continue image=batch_image[0,:] image=image.numpy()#image=np.array(image) image=image.transpose(1,2,0)#通道由[c,h,w]->[h,w,c] cv2.imshow("image",image) cv2.waitKey(2000) print("batch_image.shape:{},batch_label:{}".format(batch_image.shape,batch_label)) #batch_x,batch_y=Variable(batch_x),Variable(batch_y)
输出结果说明:
batch_size=5,输入图片列表image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"],其中"ddd.jpg","111.jpg"是不存在的,resize_width=224,正常情况下返回的数据应该是torch.Size([5,3,224,224]),但由于"ddd.jpg","111.jpg"不存在,被过滤掉了,所以第一个batch的维度变为torch.Size([3,3,224,224])
[Errno2]Nosuchfileordirectory:'../dataset/test_images/images\\ddd.jpg'
[Errno2]Nosuchfileordirectory:'../dataset/test_images/images\\111.jpg'
batch_image.shape:torch.Size([3,3,224,224]),batch_label:('1.jpg','3.jpg','4.jpg')
batch_image.shape:torch.Size([5,3,224,224]),batch_label:('5.jpg','6.jpg','7.jpg','8.jpg','9.jpg')
[Errno2]Nosuchfileordirectory:'../dataset/test_images/images\\ddd.jpg'
[Errno2]Nosuchfileordirectory:'../dataset/test_images/images\\111.jpg'
batch_image.shape:torch.Size([3,3,224,224]),batch_label:('1.jpg','3.jpg','4.jpg')
batch_image.shape:torch.Size([5,3,224,224]),batch_label:('5.jpg','6.jpg','7.jpg','8.jpg','9.jpg')
以上为个人经验,希望能给大家一个参考,也希望大家多多支持毛票票。如有错误或未考虑完全的地方,望不吝赐教。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。