PyTorch读取Cifar数据集并显示图片(转载)
2017-11-22 21:28
357 查看
转自:原文链接
train=True,读取训练集;train=False,读取测试集
download=False,不下载。如果为True,则先检查root下有无该数据集,如果没有就先下载。
使用for…in…循环读取数据的时候,会自动调用DataLoader里的next()函数
而且只能对Tensor实例进行迭代,所以之前的transforms必须最后加一个transforms.ToTensor()
显示图片有两种方式:Image.show()和plt.imshow(ndarray)
Image.show():
通过transforms.ToPILImage()把FloatTensor转化为Image
plt.imshow(ndarray):
通过FloatTensor.numpy()转化为ndarray,再调用plt.imshow()
to_pil_image = transforms.ToPILImage()
cnt = 0
for image,label in data_loader:
if cnt>=3: # 只显示3张图片
break
print(label) # 显示label
PyTorch创建DataLoader的流程
首先了解一下需要的几个类所在的packagefrom torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等。 #DataLoader读入的数据类型是PIL.Image #这里对图片不做任何处理,仅仅是把PIL.Image转换为torch.FloatTensor,从而可以被pytorch计算 transform = transforms.Compose( [ transforms.ToTensor() ] )
Step 1,得到torch.utils.data.Dataset实例。
torch.utils.data.Dataset是一个抽象类,CIFAR100是它的一个实例化子类train=True,读取训练集;train=False,读取测试集
download=False,不下载。如果为True,则先检查root下有无该数据集,如果没有就先下载。
train_set = ds.CIFAR100(root='.', train=True, transform=transform, target_transform=None, download=True)
Step 2,把Dataset封装成torch.utils.data.DataLoader
data_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=False, num_workers=2) # # 生成torch.utils.data.DataLoaderIter # # 不过DataLoaderIter它会被DataLoader自动创建并且调用,我们用不到 # data_iter = iter(data_loader) # images, labels = next(data_iter)
step 3,从DataLoader里读取数据,并将图片显示出来。
注意:使用for…in…循环读取数据的时候,会自动调用DataLoader里的next()函数
而且只能对Tensor实例进行迭代,所以之前的transforms必须最后加一个transforms.ToTensor()
显示图片有两种方式:Image.show()和plt.imshow(ndarray)
Image.show():
通过transforms.ToPILImage()把FloatTensor转化为Image
plt.imshow(ndarray):
通过FloatTensor.numpy()转化为ndarray,再调用plt.imshow()
to_pil_image = transforms.ToPILImage()
cnt = 0
for image,label in data_loader:
if cnt>=3: # 只显示3张图片
break
print(label) # 显示label
# 方法1:Image.show() # transforms.ToPILImage()中有一句 # npimg = np.transpose(pic.numpy(), (1, 2, 0)) # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维 img = to_pil_image(image[0]) img.show() # 方法2:plt.imshow(ndarray) img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维 img = img.numpy() # FloatTensor转为ndarray img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后 # 显示图片 plt.imshow(img) plt.show() cnt += 1
相关文章推荐
- PyTorch读取Cifar数据集并显示图片的实例讲解
- 转载关于android高效显示图片的文章---From 移动微技
- 真缓存显示bmp图片(转载)
- 基于Drectshow的GetCurrentImage使用方法无法显示图片的解决方法(本文转载)
- (转载)Qt:拖拽图片到QLabel上并显示
- (转载)数据库存取图片并在MVC3中显示在View中
- 【转载】opencv2 一个窗口显示多幅图片(windows7和bunutu系统)
- 转载css3 图片圆形显示 如何CSS将正方形图片显示为圆形图片布局
- 【转载】Android 使用开源库StickyGridHeaders来实现带sections和headers的GridView显示本地图片效果
- IE浏览器下同一网页多图片显示的瓶颈与优化(转载,网站图片方面优化一篇堪称经典的文章)
- 【转载】从资源中获取图片并显示
- PyTorch读取Cifar数据集并显示图片
- 近期QQ空间转载的日志有图片显示自己的头像昵称QQ号等
- QT完成图片拖拽显示【本文转载自网络】
- 转载ECTouch1.0 修改后台广告管理中广告列表显示广告图片
- 【转载】android中如何显示图片局部或者不同区域
- 【转载】从数据库中读出图片并显示的示例代码
- 转载-当图片加载失败或者没有的情况下显示默认图片
- [转载]Android开发之--读取文件夹下图片生成略缩图并点击显示大图
- 解决IE6下的,不能显示透明PNG图片的问题(转载)