您的位置:首页 > 其它

阿里天池-零基础入门CV赛事- 街景字符编码识别-数据读取与数据扩增

2020-06-02 06:08 691 查看

阿里天池-零基础入门CV赛事- 街景字符编码识别-数据读取与数据扩增

  • 利用Pytorch读取数据
  • Question
  • Reference
  • 本次任务学习目标

    1、学会用pytorch进行图像读取
    2、学会数据扩增的方法以及用pytorch读取数据

    图像读取

    Python中常见的图像读取的库有两个,分别是Pillow和OpenCv,本文将用OpenCV进行一个数据的读取。

    import cv2
    img = cv2.imread('000000.png') # Opencv默认颜色通道顺序是BRG,转换一下
    #可以转换别的
    #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) #转换为灰度图
    edges = cv2.Canny(img, 30, 70) cv2.imwrite('canny.jpg', edges)# Canny边缘检测
    #显示图片
    #cv2.imshow("image", img) # 显示图片,后面会讲解
    #cv2.waitKey(0) #等待按键

    读取一张数据集进行尝试,结果如下:

    数据扩增

    什么是数据扩增

    数据扩增,主要集中在图像和视频(视频也主要是采用分帧图像实现的)领域,可以增加训练集的样本,同时也能缓解模型模型过拟合以及给我们的模型带来更强的泛华能力。

    数据扩增的作用

    1、增加训练的样本数量
    2、扩展样本空间

    常用的数据扩增方法

    常见的方法包括反转、平移、缩放、亮度变化、裁剪、光照等外部影响、颜色变换、模糊、灰度等方法。其常用的相关代码如下:

    transforms.CenterCrop #对图片中心进行裁剪
    transforms.ColorJitter #对图像颜色的对比度、饱和度和零度进行变换
    transforms.FiveCrop #对图像四个角和中心进行裁剪得到五分图像
    transforms.Grayscale #对图像进行灰度变换
    transforms.Pad #使用固定值进行像素填充
    transforms.RandomAffine #随机仿射变换
    transforms.RandomCrop #随机区域裁剪
    transforms.RandomHorizontalFlip #随机水平翻转
    transforms.RandomRotation #随机旋转
    transforms.RandomVerticalFlip #随机垂直翻转

    利用Pytorch读取数据

    本次任务主要学习的是利用pytorch框架来读取赛题数据,它是通过dataset进行封装,再通过dataloder进行并行读取。所以我们只用重新加载数据读取的逻辑即可。代码如下:

    import os, sys, glob, shutil, json
    import cv2
    
    from PIL import Image
    import numpy as np
    
    import torch
    from torch.utils.data.dataset import Dataset
    import torchvision.transforms as transforms
    
    class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
    self.img_path = img_path
    self.img_label = img_label
    if transform is not None:
    self.transform = transform
    else:
    self.transform = None
    
    def __getitem__(self, index):
    img = Image.open(self.img_path[index]).convert('RGB')
    
    if self.transform is not None:
    img = self.transform(img)
    # 原始SVHN中类别10为数字0
    lbl = np.array(self.img_label[index], dtype=np.int)
    lbl = list(lbl) + (5 - len(lbl)) * [10]
    
    return img, torch.from_numpy(np.array(lbl[:5]))
    
    def __len__(self):
    return len(self.img_path)
    
    train_path = glob.glob('../input/train/*.png')
    train_path.sort()
    train_json = json.load(open('../input/train.json'))
    train_label = [train_json[x]['label'] for x in train_json]
    
    data = SVHNDataset(train_path, train_label,transforms.Compose([
    # 缩放到固定尺⼨寸
    transforms.Resize((64, 128)),
    
    # 随机颜⾊色变换
    transforms.ColorJitter(0.2, 0.2, 0.2),
    
    # 加⼊入随机旋转
    transforms.RandomRotation(5),
    
    # 将图⽚片转换为pytorch 的tesntor
    # transforms.ToTensor(),
    
    # 对图像像素进⾏行行归⼀一化
    # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]))

    接下来我们可以在定义好的Dataset的基础上构建一个Dataloder,dataset是对数据集的封装,提供了对数据样本进行读取的索引方式;而Dataloder是对Dataset的封装,提供批量读取的迭代读取。Dataloder的相关代码如下:

    import os, sys, glob, shutil, json
    import cv2
    
    from PIL import Image
    import numpy as np
    
    import torch
    from torch.utils.data.dataset import Dataset
    import torchvision.transforms as transforms
    
    class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
    self.img_path = img_path
    self.img_label = img_label
    if transform is not None:
    self.transform = transform
    else:
    self.transform = None
    
    def __getitem__(self, index):
    img = Image.open(self.img_path[index]).convert('RGB')
    if self.transform is not None:
    img = self.transform(img)
    # 原始SVHN中类别10为数字0
    lbl = np.array(self.img_label[index], dtype=np.int)
    lbl = list(lbl) + (5 - len(lbl)) * [10]
    
    return img, torch.from_numpy(np.array(lbl[:5]))
    
    def __len__(self):
    return len(self.img_path)
    
    train_path = glob.glob('../input/train/*.png')
    train_path.sort()
    train_json = json.load(open('../input/train.json'))
    train_label = [train_json[x]['label'] for x in train_json]
    
    train_loader = torch.utils.data.DataLoader(SVHNDataset(train_path, train_label,ransforms.Compose([transforms.Resize((64, 128)),transforms.ColorJitter(0.3, 0.3, 0.2),transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])),
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数)
    
    for data in train_loader:
    break

    通过上述代码,我们便可以将数据按照批次进行获取,每批次调用Dataset读取单个样本进行拼接,此时的data的格式为:

    torch.Size([10, 3, 64, 128]), torch.Size([10, 6])

    前面是图像文件的参数,依次为batchsizechanelheight*width,后面的为字符标签。

    Question

    运行过程中,我遇到了如下的问题,但我运行到这句代码的时候:

    for data in train_loader:
    break

    会报错:BrokenPipeError: [Errno 32] Broken pipe
    字面意思就是进程崩了,查阅了资料才发现,是由于我们torch.utils.data.DataLoader中的num_workers的问题。

    train_loader = torch.utils.data.DataLoader(SVHNDataset(train_path, train_label,ransforms.Compose([transforms.Resize((64, 128)),transforms.ColorJitter(0.3, 0.3, 0.2),transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])),
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数

    出现上面报错的原因有可能是系统占用的线程过多了,所以会崩。将num_workers=0,就不会在报错了。等于0也就等于没有不调用多线程,也就意味着我们的运行速度会非常非常缓慢,每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。
    详情参考链接: link.

    Reference

    由Datawhale提供的《零基础入门CV实践教程》

    内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
    标签: 
    相关文章推荐