pytorch使用(一)处理并加载自己的数据
2017-12-16 19:20
746 查看
pytorch使用:目录
pytorch使用(一)数据处理
个人认为,数据处理或许是在完成一篇论文中最耗费时间的,特别是大多情况下,需要在很多个库上做实验。pytorch官方支持很多库,使用torchvision来完成数据的处理,点这里可以看到支持的库并不是很多。在这里,我将结合一个实例说明如何使用pytorch来处理自己的数据,任务是一个分析双臂运动的,检测6个关节点的运动。输入是连续三帧的检测结果以及计算的光流,也就是
$3*6+2*2=22$张heatmap,输出是中间帧的检测结果,也就是6张heatmap。
把原始数据处理为模型使用的数据需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分别可以理解为数据处理格式的定义、数据处理和数据加载。
1. 数据预处理torchvision.transforms
pytorch使用torchvision.transforms实现数据的预处理,包括中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等,建议整体浏览一下这一部分的官方手册,非常有用,数据处理很方便。先转换为张量,然后正则化:
import torchvision.transforms as transforms transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]) #img = transform(img)
2. 数据读取,构建Dataset子类
参考:http://blog.csdn.net/victoriaw/article/details/72356453如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。我们的将数据列表放在
train.txt和
test.txt中,将不同类型的数据的路径放在
path.txt中,所以在类的init函数中有path_file和 list_file两个变量
在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是len和getitem:
- len返回数据集的大小
- getitem实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。
末尾有自己写的一个Dataset子类的定义文件。
3. 数据加载
torch.utils.data.DataLoader()函数,合成数据并且提供迭代访问。主要由两部分组成:
- dataset(Dataset)。输入加载的数据,就是上面的
MyDataset的实现。
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:
- batch-size。样本每个batch的大小,默认为1。 - shuffle。是否打乱数据,默认为False。 - num_workers。数据分为几个线程处理默认为0。 - sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False
使用:
import torch from datagen import MyDataset trainset = MyDataset(path_file=pathFile,list_file=trainList,numJoints = 6,type=False) trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8) testset = MyDataset(path_file=pathFile,list_file=testList,numJoints = 6,type=False) testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)
以下是定义
class MyDataset文件
datagen.py, 其中有
__init__(self, path_file, list_file,numJoints,type)、
__getitem__(self, idx)、
__len__(self)三个函数,
__getitem__返回一个(22,256,256)的输入和一个(6,256,256)的标签。
''' Load data ''' import numpy as np from PIL import Image #import cv2 import torch import torch.utils.data as data import torchvision.transforms as transforms class MyDataset(data.Dataset): def __init__(self, path_file, list_file,numJoints,type): ''' Args: path_file: (str) heatmap and optical file location list_file: (str) path to index file. numJoints: (int) number of joints type: (boolean) use pose flow(true) or optical flow(false) ''' self.numJoints = numJoints # read heatmap and optical path with open(path_file) as f: paths = f.readlines() for path in paths: splited = path.strip().split() if splited[0]=='resPath': self.resPath = splited[1] elif splited[0]=='gtPath': self.gtPath = splited[1] elif splited[0]=='opticalFlowPath': self.opticalFlowPath = splited[1] elif splited[0]=='poseFlowPath': self.poseFlowPath = splited[1] if type: self.flowPath = self.poseFlowPath else: self.flowPath = self.opticalFlowPath #read list with open(list_file) as f: self.list = f.readlines() self.num_samples = len(self.list) def __getitem__(self, idx): ''' load heatmaps and optical flow and encode it to a 22 channels input and 6 channels output :param idx: (int) image index :return: input: a 22 channel input which integrate 2 optical flow and heatmaps of 3 image output: the ground truth ''' input = [] output = [] # load heatmaps of 3 image for im in range(3): for map in range(6): curResPath = self.resPath + self.list[idx].rstrip('\n') + str(im + 1) + '/' + str(map + 1) + '.bmp' heatmap = Image.open(curResPath) heatmap.load() heatmap = np.asarray(heatmap, dtype='float') / 255 input.append(heatmap) # load 2 flow for flow in range(2): curFlowXPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowx/' + str(flow + 1) + '.jpg' flowX = Image.open(curFlowXPath) flowX.load() flowX = np.asarray(flowX, dtype='float') curFlowYPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowy/' + str(flow + 1) + '.jpg' flowY = Image.open(curFlowYPath) flowY.load() flowY = np.asarray(flowY, dtype='float') input.append(flowX) input.append(flowY) # load groundtruth for map in range(6): curgtPath = self.resPath + self.list[idx].rstrip('\n') + str(2) + '/' + str(map + 1) + '.bmp' heatmap = Image.open(curResPath) heatmap.load() heatmap = np.asarray(heatmap, dtype='float') / 255 output.append(heatmap) input = torch.Tensor(input) output = torch.Tensor(output) return input,output def __len__(self): return self.num_samples
相关文章推荐
- Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理
- PyTorch使用并行GPU处理数据
- 关于使用PyTorch设置多线程(threads)进行数据读取而导致GPU显存始终不释放的问题
- Flex 3快速入门: 处理数据 使用 HTTPService 加载外部数据
- pytorch多GPU训练以及多线程加载数据
- 数字图像处理 CImage类的使用与封装(jpg png gif tif bmp等格式图像的加载、数据读写、保存等功能)
- 使用faster rcnn训练自己的数据(py-faster-rcnn )
- 使用py-faster-rcnn训练自己的数据
- Flex 3快速入门: 处理数据 使用 HTTPService 加载外部数据
- 在Android Studio上使用GSON+VOLLEY,秒处理网络数据成集合。感受框架的力量。搭配RecyclerView和SwipeRefreshLayout,实现底端加载更多,下拉刷新。
- pytorch: 准备、训练和测试自己的图片数据
- tensorflow处理自己的图像数据(不使用队列)
- 使用数据2分处理的通用分页存储过程 前半部分与后半部分数据访问时间相同,同等访问速度提高一倍
- 跟我一起学Windows Workflow Foundation(3)-----使用If/Else活动,定制活动处理工作流,使用事件传递数据
- 使用SqlBulkCopy类加载其他源数据到SQL表
- 使用asp.net 2.0的CreateUserwizard控件如何向自己的数据表中添加数据
- 自己写的使用聚集函数实现多行字串合并处理
- 使用Hibernate处理数据
- 在GridView中处理数据不使用Data Source Controls
- 使用 EL、JSTL 处理表单数据(转载)