您的位置:首页 > 编程语言 > Python开发

PyTorch读取Cifar数据集并显示图片

2017-05-31 16:35 686 查看

首先了解一下需要的几个类所在的package



from torchvision import transforms, datasets as ds
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

#transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等。
#DataLoader读入的数据类型是PIL.Image
#这里对图片不做任何处理,仅仅是把PIL.Image转换为torch.FloatTensor,从而可以被pytorch计算
transform = transforms.Compose(
[
transforms.ToTensor()
]
)


Step 1,得到
torch.utils.data.Dataset
实例。

torch.utils.data.Dataset
是一个抽象类,
CIFAR100
是它的一个实例化子类

train=True
,读取训练集;
train=False
,读取测试集

download=False
,不下载。如果为
True
,则先检查
root
下有无该数据集,如果没有就先下载。

train_set = ds.CIFAR100(root='.', train=True, transform=transform, target_transform=None, download=True)


Step 2,把Dataset封装成torch.utils.data.DataLoader

data_loader = DataLoader(dataset=train_set,
batch_size=1,
shuffle=False,
num_workers=2)

# # 生成torch.utils.data.DataLoaderIter
# # 不过DataLoaderIter它会被DataLoader自动创建并且调用,我们用不到
# data_iter = iter(data_loader)
# images, labels = next(data_iter)


step 3,从
DataLoader
里读取数据,并将图片显示出来。

注意:

1)使用
for...in...
循环读取数据的时候,会自动调用
DataLoader
里的
__next__()
函数

而且只能对
Tensor
实例进行迭代,所以之前的
transforms
必须最后加一个
transforms.ToTensor()


2)显示图片有两种方式:
Image.show()
plt.imshow(ndarray)


Image.show()
:

通过
transforms.ToPILImage()
FloatTensor
转化为
Image


plt.imshow(ndarray)


通过
FloatTensor.numpy()
转化为
ndarray
,再调用
plt.imshow()


to_pil_image = transforms.ToPILImage()
cnt = 0
for image,label in data_loader:
if cnt>=3:      # 只显示3张图片
break
print(label)    # 显示label

# 方法1:Image.show()
# transforms.ToPILImage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
img = to_pil_image(image[0])
img.show()

# 方法2:plt.imshow(ndarray)
img = image[0]      # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
img = img.numpy()   # FloatTensor转为ndarray
img = np.transpose(img, (1,2,0))    # 把channel那一维放到最后

# 显示图片
plt.imshow(img)
plt.show()

cnt += 1


另外补一句
np.transpose()
的用法。

第一个参数是要transpose的图片;

第二个是shape。比如一个ndarray是
(channel, height, width)
,如果给第二个参数
(height, width,channel)
,就会把第0维
channel
整个搬到最后。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息