关于pytorch处理类别不平衡的问题
当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。
下面的代码展示了如何使用WeightedRandomSampler来完成抽样。
numDataPoints=1000 data_dim=5 bs=100 #Createdummydatawithclassimbalance9to1 data=torch.FloatTensor(numDataPoints,data_dim) target=np.hstack((np.zeros(int(numDataPoints*0.9),dtype=np.int32), np.ones(int(numDataPoints*0.1),dtype=np.int32))) print'targettrain0/1:{}/{}'.format( len(np.where(target==0)[0]),len(np.where(target==1)[0])) class_sample_count=np.array( [len(np.where(target==t)[0])fortinnp.unique(target)]) weight=1./class_sample_count samples_weight=np.array([weight[t]fortintarget]) samples_weight=torch.from_numpy(samples_weight) samples_weight=samples_weight.double() sampler=WeightedRandomSampler(samples_weight,len(samples_weight)) target=torch.from_numpy(target).long() train_dataset=torch.utils.data.TensorDataset(data,target) train_loader=DataLoader( train_dataset,batch_size=bs,num_workers=1,sampler=sampler) fori,(data,target)inenumerate(train_loader): print"batchindex{},0/1:{}/{}".format( i, len(np.where(target.numpy()==0)[0]), len(np.where(target.numpy()==1)[0]))
核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:
class_sample_count=np.array( [len(np.where(target==t)[0])fortinnp.unique(target)]) weight=1./class_sample_count samples_weight=np.array([weight[t]fortintarget]) samples_weight=torch.from_numpy(samples_weight) samples_weight=samples_weight.double() sampler=WeightedRandomSampler(samples_weight,len(samples_weight))
参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。