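An investigation of mean and std in PyTorch, with examples
The script below checks how the mean and std passed to transforms.Normalize are applied to an image: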
# coding: utf-8
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

imgpath = 'D:/ck/files_detected_face224/'
imgname = 'S055_002_00000025.png'  # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]

# OpenCV loads images as BGR; reversing the last axis gives RGB.
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224,) * 2)

# ToTensor scales uint8 H*W*C data to floats in [0, 1] with layout C*H*W;
# Normalize then subtracts the mean and divides by the std per channel.
image = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=mean_file,
        std=std_file,
    )
])(raw_image).unsqueeze(0)
print(image.shape)

# Undo the normalization by hand and compare with the raw image.
convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1)  # 3*224*224, C*H*W
convert_image1 = convert_image1 * np.reshape(std_file, (3, 1, 1)) + np.reshape(mean_file, (3, 1, 1))
convert_image1 = np.transpose(convert_image1, (1, 2, 0))  # H*W*C
print(convert_image1.shape)
convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(np.abs(diff))  # largest absolute deviation, should be close to 0
print(err)

# Round before casting so values such as 254.9999 do not truncate to 254.
plt.imshow(np.uint8(np.clip(np.round(convert_image1), 0, 255)))
plt.show()
Conclusion:
input_image = (raw_image / 255 - mean) / std   (element-wise, with mean and std applied per channel)
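The relationship above can also be checked without an image file. The following is a minimal sketch using only NumPy, not the original script: the random raw_image array and the helper names normalize and denormalize are illustrative assumptions, and the array is kept in H*W*C layout (ToTensor would additionally move the channels to the front).

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def normalize(raw_image):
    """Map an H*W*C uint8 RGB image to normalized float values.

    Mirrors input_image = (raw_image / 255 - mean) / std; the channel
    axis (last) broadcasts against the 3-element mean and std.
    """
    return (raw_image.astype(np.float64) / 255.0 - mean) / std


def denormalize(input_image):
    """Invert normalize() and return values on the 0..255 scale."""
    return (input_image * std + mean) * 255.0


# Hypothetical stand-in for an image loaded from disk.
rng = np.random.default_rng(0)
raw_image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

input_image = normalize(raw_image)
restored = denormalize(input_image)

# Only floating-point rounding error should remain.
max_err = np.max(np.abs(restored - raw_image))
print(max_err)
```

If max_err is far from zero for a real image, the usual suspects are a BGR/RGB mix-up or a missing division by 255.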
Next, we investigate how the mean and standard deviation values are generated:
mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10', 'cifar100', 'mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
                    help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')
args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)
print(args.dataset)

# Note: older torchvision releases exposed the raw training array as
# train_set.train_data; current releases call it train_set.data.
if args.dataset == "cifar10":
    train_transform = transforms.Compose([transforms.ToTensor()])
    train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
    # print(vars(train_set))
    print(train_set.data.shape)
    # Per-channel statistics over all samples, rows and columns, scaled to [0, 1].
    print(train_set.data.mean(axis=(0, 1, 2)) / 255)
    print(train_set.data.std(axis=(0, 1, 2)) / 255)

    # Show one image.
    train_data = train_set.data
    ind = 100
    img0 = train_data[ind, ...]
    # Channel-order check: the stored channels are RGB, not BGR as in Caffe.
    # Swapping the channels as below produces wrong colours:
    # b, g, r = cv2.split(img0)
    # img0 = cv2.merge([r, g, b])
    print(img0.shape)
    print(type(img0))
    plt.imshow(img0)
    plt.show()  # a ship at sea
    # img0 = cv2.resize(img0, (224, 224))
    # cv2.imshow('img0', img0)
    # cv2.waitKey()
elif args.dataset == "cifar100":
    train_transform = transforms.Compose([transforms.ToTensor()])
    train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
    # print(vars(train_set))
    print(train_set.data.shape)
    print(np.mean(train_set.data, axis=(0, 1, 2)) / 255)
    print(np.std(train_set.data, axis=(0, 1, 2)) / 255)
elif args.dataset == "mnist":
    train_transform = transforms.Compose([transforms.ToTensor()])
    train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
    # print(vars(train_set))
    # MNIST is single-channel and stored as a uint8 tensor, so convert to float first.
    print(list(train_set.data.size()))
    print(train_set.data.float().mean() / 255)
    print(train_set.data.float().std() / 255)
Result:
cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[0.49139968 0.48215841 0.44653091]
[0.24703223 0.24348513 0.26158784]
(32, 32, 3)
The same computation of mean_file and std_file can be cross-checked in MATLAB:
% load cifar10 dataset
data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));
temp = mean(train_data, 1);
disp(size(temp));
train_data = double(train_data);

% compute mean_file
mean_val = mean(mean(mean(train_data, 1), 2), 3) / 255;

% compute std_file
temp1 = train_data(:, :, :, 1);
std_val1 = std(temp1(:)) / 255;
temp2 = train_data(:, :, :, 2);
std_val2 = std(temp2(:)) / 255;
temp3 = train_data(:, :, :, 3);
std_val3 = std(temp3(:)) / 255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];
disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%         std_val:  [0.2470, 0.2435, 0.2616]
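The mean could also be computed with the same procedure as the standard deviation, that is, by flattening each channel first. For simplicity, however, note that for a matrix the mean of all elements equals the mean taken along one dimension followed by the mean along the other, because every slice contains the same number of elements. The following form is therefore used directly: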
mean_val = mean(mean(mean(train_data, 1), 2), 3) / 255;
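The equivalence between the nested mean and the mean over all elements can be confirmed with a short NumPy sketch. The random array below is a small hypothetical stand-in for the CIFAR-10 training data (same N*H*W*C layout, fewer samples), so the printed values will not match the CIFAR-10 statistics.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for train_data: N*H*W*C, uint8.
train_data = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
data = train_data.astype(np.float64)

# Nested form: average over samples, then rows, then columns.
# After each reduction the next spatial axis becomes axis 0.
nested = data.mean(axis=0).mean(axis=0).mean(axis=0) / 255

# Direct form: average over all three axes at once, one value per channel.
direct = data.mean(axis=(0, 1, 2)) / 255

print(nested)
print(direct)
print(np.allclose(nested, direct))
```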
The standard deviation is computed separately for each channel, over all pixels of all samples, and then divided by 255.
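Unlike the mean, the standard deviation cannot be obtained by reducing one axis at a time, which is why each channel is flattened first. The sketch below illustrates this on a small hypothetical random array (not the real CIFAR-10 data). It also shows one detail worth knowing when comparing the two tools: NumPy's std divides by N by default, while MATLAB's std divides by N-1, so the results differ very slightly, by an amount that shrinks as the number of pixels grows.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for train_data: N*H*W*C, converted to float.
train_data = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8).astype(np.float64)

# Flatten to (num_pixels, 3): one row per pixel, one column per channel.
per_channel = train_data.reshape(-1, 3)

std_numpy = per_channel.std(axis=0) / 255            # divides by N (NumPy default, ddof=0)
std_matlab = per_channel.std(axis=0, ddof=1) / 255   # divides by N-1, as MATLAB's std does

# Reducing axis by axis gives a different (wrong) quantity for the std.
std_nested = train_data.std(axis=0).std(axis=0).std(axis=0) / 255

print(std_numpy)
print(std_matlab)
print(std_nested)
```

The first two lines printed agree to several decimal places; the third does not, because a standard deviation of standard deviations measures something else entirely.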