您的位置:首页 > 其它

PyTorch读取Cifar数据集并显示图片(转载)

2017-11-22 21:28 357 查看
转自:原文链接

PyTorch创建DataLoader的流程

首先了解一下需要的几个类所在的package



from torchvision import transforms, datasets as ds
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

#transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等。
#DataLoader读入的数据类型是PIL.Image
#这里对图片不做任何处理,仅仅是把PIL.Image转换为torch.FloatTensor,从而可以被pytorch计算
transform = transforms.Compose(
[
transforms.ToTensor()
]
)


Step 1,得到torch.utils.data.Dataset实例。

torch.utils.data.Dataset是一个抽象类,CIFAR100是它的一个实例化子类

train=True,读取训练集;train=False,读取测试集

download=False,不下载。如果为True,则先检查root下有无该数据集,如果没有就先下载。

train_set = ds.CIFAR100(root='.', train=True, transform=transform, target_transform=None, download=True)


Step 2,把Dataset封装成torch.utils.data.DataLoader

data_loader = DataLoader(dataset=train_set,
batch_size=1,
shuffle=False,
num_workers=2)
# # 生成torch.utils.data.DataLoaderIter
# # 不过DataLoaderIter它会被DataLoader自动创建并且调用,我们用不到
# data_iter = iter(data_loader)
# images, labels = next(data_iter)


step 3,从DataLoader里读取数据,并将图片显示出来。

注意:

使用for…in…循环读取数据的时候,会自动调用DataLoader里的next()函数

而且只能对Tensor实例进行迭代,所以之前的transforms必须最后加一个transforms.ToTensor()

显示图片有两种方式:Image.show()和plt.imshow(ndarray)

Image.show():

通过transforms.ToPILImage()把FloatTensor转化为Image

plt.imshow(ndarray):

通过FloatTensor.numpy()转化为ndarray,再调用plt.imshow()

to_pil_image = transforms.ToPILImage()

cnt = 0

for image,label in data_loader:

if cnt>=3: # 只显示3张图片

break

print(label) # 显示label

# 方法1:Image.show()
# transforms.ToPILImage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
img = to_pil_image(image[0])
img.show()

# 方法2:plt.imshow(ndarray)
img = image[0]      # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
img = img.numpy()   # FloatTensor转为ndarray
img = np.transpose(img, (1,2,0))    # 把channel那一维放到最后

# 显示图片
plt.imshow(img)
plt.show()

cnt += 1
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: