pytorch 数据集图片显示方法
图片显示
pytorch载入的数据集是元组tuple形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用。
同样给一些刚入门的同学在使用载入的数据显示图片的时候带来一些难以理解的地方,这里主要是将Tensor与numpy转换的过程,理解了这些就可以就行转换了
CIAFA10数据集
首先载入数据集,这里做了一些数据处理,包括图片尺寸、数据归一化等
importtorch fromtorch.autogradimportVariable importmatplotlib.pyplotasplt importtorchvision.datasetsasdset importtorchvision.transformsastransforms fromautoencoderimportAutoEncoder importtorch.nnasnn importtorchvision importnumpyasnp dataset=dset.CIFAR10(root='../train/data',download=True, transform=transforms.Compose([ transforms.Scale(200), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)), transforms.Gray() ]))
在这里dataset是一个CIFAR10对象,(大家可以查看一下他的源代码)
方式一
dataset[1]=([torch.FloatTensorofsize1x200x200],9)
载入的第二个数据是个tensor格式,包含一个标签9
这里我们做的就是将torch.FloatTensor转换为numpy,然后显示
b=dataset[1][0].numpy() #取数据,不取标签
因为这里的b仍然是1*200*200的大小,所以要重新reshape一下,适合输出图像
plt.imshow(b.reshape(200,200),cmap='gray') plt.show()
然后可以显示图像了
方式二
利用torch的接口
img=torchvision.utils.make_grid(dataset[1][0]).numpy() plt.imshow(np.transpose(img,(1,2,0))) plt.show()
这用np.transpose是因为plt.imshow在显示时候输入的是(imgsize,imgsieze,channels),而这里得到的img是(3,200,200)的格式,所以进行了转换,才能显示
以上这篇pytorch数据集图片显示方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。