计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
pytorch做标准化利用transforms.Normalize(mean_vals,std_vals),其中常用数据集的均值方差有:
if'coco'inargs.dataset: mean_vals=[0.471,0.448,0.408] std_vals=[0.234,0.239,0.242] elif'imagenet'inargs.dataset: mean_vals=[0.485,0.456,0.406] std_vals=[0.229,0.224,0.225]
计算自己数据集图像像素的均值方差:
importnumpyasnp
importcv2
importrandom
#calculatemeansandstd
train_txt_path='./train_val_list.txt'
CNum=10000#挑选多少图片进行计算
img_h,img_w=32,32
imgs=np.zeros([img_w,img_h,3,1])
means,stdevs=[],[]
withopen(train_txt_path,'r')asf:
lines=f.readlines()
random.shuffle(lines)#shuffle,随机挑选图片
foriintqdm_notebook(range(CNum)):
img_path=os.path.join('./train',lines[i].rstrip().split()[0])
img=cv2.imread(img_path)
img=cv2.resize(img,(img_h,img_w))
img=img[:,:,:,np.newaxis]
imgs=np.concatenate((imgs,img),axis=3)
#print(i)
imgs=imgs.astype(np.float32)/255.
foriintqdm_notebook(range(3)):
pixels=imgs[:,:,i,:].ravel()#拉成一行
means.append(np.mean(pixels))
stdevs.append(np.std(pixels))
#cv2读取的图像格式为BGR,PIL/Skimage读取到的都是RGB不用转
means.reverse()#BGR-->RGB
stdevs.reverse()
print("normMean={}".format(means))
print("normStd={}".format(stdevs))
print('transforms.Normalize(normMean={},normStd={})'.format(means,stdevs))
以上这篇计算pytorch标准化(Normalize)所需要数据集的均值和方差实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。