阿里天池-零基础入门CV赛事- 街景字符编码识别-数据读取与数据扩增
阿里天池-零基础入门CV赛事- 街景字符编码识别-数据读取与数据扩增
本次任务学习目标
1、学会用pytorch进行图像读取
2、学会数据扩增的方法以及用pytorch读取数据
图像读取
Python中常见的图像读取的库有两个,分别是Pillow和OpenCv,本文将用OpenCV进行一个数据的读取。
import cv2 img = cv2.imread('000000.png') # Opencv默认颜色通道顺序是BRG,转换一下 #可以转换别的 #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) #转换为灰度图 edges = cv2.Canny(img, 30, 70) cv2.imwrite('canny.jpg', edges)# Canny边缘检测 #显示图片 #cv2.imshow("image", img) # 显示图片,后面会讲解 #cv2.waitKey(0) #等待按键
读取一张数据集进行尝试,结果如下:
数据扩增
什么是数据扩增
数据扩增,主要集中在图像和视频(视频也主要是采用分帧图像实现的)领域,可以增加训练集的样本,同时也能缓解模型模型过拟合以及给我们的模型带来更强的泛华能力。
数据扩增的作用
1、增加训练的样本数量
2、扩展样本空间
常用的数据扩增方法
常见的方法包括反转、平移、缩放、亮度变化、裁剪、光照等外部影响、颜色变换、模糊、灰度等方法。其常用的相关代码如下:
transforms.CenterCrop #对图片中心进行裁剪 transforms.ColorJitter #对图像颜色的对比度、饱和度和零度进行变换 transforms.FiveCrop #对图像四个角和中心进行裁剪得到五分图像 transforms.Grayscale #对图像进行灰度变换 transforms.Pad #使用固定值进行像素填充 transforms.RandomAffine #随机仿射变换 transforms.RandomCrop #随机区域裁剪 transforms.RandomHorizontalFlip #随机水平翻转 transforms.RandomRotation #随机旋转 transforms.RandomVerticalFlip #随机垂直翻转
利用Pytorch读取数据
本次任务主要学习的是利用pytorch框架来读取赛题数据,它是通过dataset进行封装,再通过dataloder进行并行读取。所以我们只用重新加载数据读取的逻辑即可。代码如下:
import os, sys, glob, shutil, json import cv2 from PIL import Image import numpy as np import torch from torch.utils.data.dataset import Dataset import torchvision.transforms as transforms class SVHNDataset(Dataset): def __init__(self, img_path, img_label, transform=None): self.img_path = img_path self.img_label = img_label if transform is not None: self.transform = transform else: self.transform = None def __getitem__(self, index): img = Image.open(self.img_path[index]).convert('RGB') if self.transform is not None: img = self.transform(img) # 原始SVHN中类别10为数字0 lbl = np.array(self.img_label[index], dtype=np.int) lbl = list(lbl) + (5 - len(lbl)) * [10] return img, torch.from_numpy(np.array(lbl[:5])) def __len__(self): return len(self.img_path) train_path = glob.glob('../input/train/*.png') train_path.sort() train_json = json.load(open('../input/train.json')) train_label = [train_json[x]['label'] for x in train_json] data = SVHNDataset(train_path, train_label,transforms.Compose([ # 缩放到固定尺⼨寸 transforms.Resize((64, 128)), # 随机颜⾊色变换 transforms.ColorJitter(0.2, 0.2, 0.2), # 加⼊入随机旋转 transforms.RandomRotation(5), # 将图⽚片转换为pytorch 的tesntor # transforms.ToTensor(), # 对图像像素进⾏行行归⼀一化 # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) ]))
接下来我们可以在定义好的Dataset的基础上构建一个Dataloder,dataset是对数据集的封装,提供了对数据样本进行读取的索引方式;而Dataloder是对Dataset的封装,提供批量读取的迭代读取。Dataloder的相关代码如下:
import os, sys, glob, shutil, json import cv2 from PIL import Image import numpy as np import torch from torch.utils.data.dataset import Dataset import torchvision.transforms as transforms class SVHNDataset(Dataset): def __init__(self, img_path, img_label, transform=None): self.img_path = img_path self.img_label = img_label if transform is not None: self.transform = transform else: self.transform = None def __getitem__(self, index): img = Image.open(self.img_path[index]).convert('RGB') if self.transform is not None: img = self.transform(img) # 原始SVHN中类别10为数字0 lbl = np.array(self.img_label[index], dtype=np.int) lbl = list(lbl) + (5 - len(lbl)) * [10] return img, torch.from_numpy(np.array(lbl[:5])) def __len__(self): return len(self.img_path) train_path = glob.glob('../input/train/*.png') train_path.sort() train_json = json.load(open('../input/train.json')) train_label = [train_json[x]['label'] for x in train_json] train_loader = torch.utils.data.DataLoader(SVHNDataset(train_path, train_label,ransforms.Compose([transforms.Resize((64, 128)),transforms.ColorJitter(0.3, 0.3, 0.2),transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=10, # 每批样本个数 shuffle=False, # 是否打乱顺序 num_workers=10, # 读取的线程个数) for data in train_loader: break
通过上述代码,我们便可以将数据按照批次进行获取,每批次调用Dataset读取单个样本进行拼接,此时的data的格式为:
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])
前面是图像文件的参数,依次为batchsizechanelheight*width,后面的为字符标签。
Question
运行过程中,我遇到了如下的问题,但我运行到这句代码的时候:
for data in train_loader: break
会报错:BrokenPipeError: [Errno 32] Broken pipe
字面意思就是进程崩了,查阅了资料才发现,是由于我们torch.utils.data.DataLoader中的num_workers的问题。
train_loader = torch.utils.data.DataLoader(SVHNDataset(train_path, train_label,ransforms.Compose([transforms.Resize((64, 128)),transforms.ColorJitter(0.3, 0.3, 0.2),transforms.RandomRotation(5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])), batch_size=10, # 每批样本个数 shuffle=False, # 是否打乱顺序 num_workers=10, # 读取的线程个数
出现上面报错的原因有可能是系统占用的线程过多了,所以会崩。将num_workers=0,就不会在报错了。等于0也就等于没有不调用多线程,也就意味着我们的运行速度会非常非常缓慢,每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。
详情参考链接: link.
Reference
由Datawhale提供的《零基础入门CV实践教程》
- Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增
- 天池-街景字符编码识别2-数据读取与数据扩增
- Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增
- Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增
- 零基础入门CV赛事- 数据读取与数据扩增
- Datawhale 零基础入门CV - Task 02 数据读取与数据扩增
- Datawhale 零基础入门CV赛事-Task03:字符识别模型
- 零基础入门CV之街道字符识别(二)
- 【学习记录】day2 Task2 数据读取与数据扩增(Datawhale 零基础⼊⻔CV)
- Datawhale零基础⼊⻔CV-Task2 数据读取与数据扩增
- 零基础入门CV之街道字符识别(三)
- 零基础⼊⻔CV-Task2 数据读取与数据扩增
- Java基础(字符,字符数据写入磁盘和读取数据的过程,return用法,break,continue区别)
- Python入门(二)——IDE选择PyCharm,输入和输出,基础规范,数据类型和变量,常量,字符串和编码,格式化
- Datawhale 零基础入门CV赛事-Task1 赛题理解
- JAVA基础初探(十三)IO简介、字节流与字符流区别、带缓冲的字节/字符流读取数据、FileReader/FileWriter便捷类、Apache IO库使用说明
- 天池比赛:街景字符编码识别(T1赛题理解)
- Python入门(二)——IDE选择PyCharm,输入和输出,基础规范,数据类型和变量,常量,字符串和编码,格式化
- 天池-街景字符编码识别4-模型训练与验证
- Datawhale 零基础入门CV赛事-Task4 模型训练与验证