您的位置:首页 > 其它

Pytorch学习小记1:torch.utils.data.Dataset类和datasat.py文件的初读笔记

2019-08-29 16:40 3431 查看
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 本文链接:https://blog.csdn.net/weixin_42214565/article/details/100136685

昨天在使用torch.utils.data.DataLoader类时遇到了一些问题,通过粗略的学习了解到了Pytorch的数据处理主要基于三个类:Dataset,DatasetLoader和DatasetLoaderIter,并且它们依次构成封装关系。关于他们之间的关系和解读,这篇文章总结得比较好理解:https://zhuanlan.zhihu.com/p/30934236

于是追根溯源找到了定义Dataset类的文档dataset.py学习了一下,顺便恶补一下我一个月突击学习的python编程基础...

1. 初读dataset.py

通读了一下dataset.py文件里的注释,了解到它定义了Dataset类,一个函数random_split()以及它的四个子类IterableDataset类、TensorDataset类、ConcatDataset类和Subset类;其中IterableDataset类又衍生出一个子类ChainDataset类。这里仅对他们做一个简要的介绍,主要理解各个类的作用是什么,以及注释表明的注意事项,具体细节日后再做学习。

1.1 Dataset类

Dataset类的描述如下:

[code]r"""An abstract class representing a :class:`Dataset`.

All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.

.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices.  To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""

Dataset类是一个描述数据集的抽象类,可供我们定义自己的数据集。只要是以“标签-数据”保存的数据集,要想利用pytorch框架进行处理,都需要继承Dataset类来定义,原文的描述是:所有以键和数据样本构成映射的数据集都需要继承Dataset类。

要想通过继承定义自己的数据集,需要改写数据集中的成员函数__getitem__()以构建键->样本的索引形式。同样可以改写成员函数__len__()用来实现返回自己定义的数据集的大小

1.2 IterableDataset类

IterableDataset类是一个可迭代的Dataset类,定义中与Dataset不同的是它的__add__()返回值返回了一个ChainDataset()数据集,而Dataset返回的是ContactDataset().当数据集的数据样本需要做迭代处理的时候,需要继承IterableDataset类,尤其是对数据流的处理非常有用。

当基于IterableDataset定义的数据集被Dataloader处理的时候,Dataloader会为它的成员产生一个迭代器。(这也解开了之前我对TensorDataset被Dataloader处理之后进行迭代时报错的疑惑,由于TensorDataset直接继承Dataset类,不可迭代,所以对Dataloader迭代操作的时候会报错TypeError: 'DataLoader' object is not an iterator)

此外,注释还对多线程处理数据进行了补充说明,给了两个例子,还是直接上代码比较清楚:

[code]Example 1: splitting workload across all workers in :meth:`__iter__`::

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.Dat
3ff7
aLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

这段程序定义可迭代类MyIterableDataset用于实现输入起止值,保存起止值之间的所有自然数(包括开始数但不包括结束的数)。如果不像Example1 一样在成员函数__iter__中定义多线程处理的方案的话,直接丢给Dataloader处理的时候,如果定义多线程,会把全部数据存入各个指定的线程处理,那么结果就会重复输出(见Example2 中定义worker_init_fn()函数之前的代码)。

1.3 TensorDataset类

TensorDataset类专门用来存储张量,它可以沿第一维度的数据索引张量的每个样本,所以我们载入的张量的第一维数据的大小就必须要相同才行。

这也解释了为什么之前学习过程中载入数据时,要先对数据做一个torch.unsqueeze()处理了。当数据只有一维时,是没法通过一维数据索引的(因为仅有的一个维度就是样本本身),程序就会报错。

1.4 ConcatDataset类

ConcatDataset类是一个串联数据集,可以用来组合不同的现有数据集。具体实现办法可以从构造函数中看出:如果传入的数据集非空->那么就把数据集封装进一个列表->传入成员self.dataset->确保成员不可迭代之后->更新数据集大小。从构造函数中也可以看出,这个类同样不支持迭代。(同样解释了IterableDataset类和Dataset类的__add__()函数调用的类不一样的原因)

看代码就可以大概推测出这个类主要就是用于拼接不同数据集的。

1.5 ChainDataset类

理解了ConcatDataset,同样可以类比出ChainDataset的大概含义。ChainDataset就是用来链接多个可迭代数据集(IterableDataset类及其派生类)的。值得一提的是,链接操作是即时完成的,所以适用于处理大规模的数据流。构造时需要输入待链接的数据集,很好理解。

1.6 Subset类

Subset类用于存储数据集和索引。构造参数为数据集(母集)及目标索引,该类的成员函数__getitem__()可以根据输入索引返回索引对应的数据集。

1.7 random_split()函数

用于随机将数据集拆分为给定长度的非重叠数据集。通过使用函数加深一下理解:

[code]# 制作伪数据用于实验
x = torch.unsqueeze(torch.linspace(-1, 1, 10), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))
print(x)
print(y)

# 载入torchDataset
torch_dataset = Data.TensorDataset(x, y)

# 制作迭代用的列表,含义是:每次随机抽取两个样本长度的数据组成队列
list1 = []
for i in range(torch_dataset.__len__()//2):
list1.append(2)

# 随即拆分并打印结果
subset = Data.random_split(torch_dataset, list1)
for i in range(subset.__len__()):
print(subset.__getitem__(i)[1])

输出结果是:

[code]tensor([[-1.0000],
[-0.7778],
[-0.5556],
[-0.3333],
[-0.1111],
[ 0.1111],
[ 0.3333],
[ 0.5556],
[ 0.7778],
[ 1.0000]])
tensor([[1.0000],
[0.6049],
[0.3086],
[0.1111],
[0.0123],
[0.0123],
[0.1111],
[0.3086],
[0.6049],
[1.0000]])
(tensor([-0.1111]), tensor([0.0123]))
(tensor([-0.5556]), tensor([0.3086]))
(tensor([0.5556]), tensor([0.3086]))
(tensor([0.7778]), tensor([0.6049]))
(tensor([-0.7778]), tensor([0.6049]))

输入数据中,x是-1到1之间均等分割出来的10个数,y是x对应数的平方,程序把x和y组合到了TensorDataset里,并进行随机分割,分割成了五组数据集,并且能看出来是随机分割的。

需要补充说明的是,random_split()函数对输入数据要求非常苛刻,分割索引必须以可迭代对象表示,并且对象所有的成员之和必须等于输入数据集的和,否则会报错:

ValueError: Su(subset[1])m of input lengths does not equal the length of the input dataset!

2. torch.utils.data.Dataset类学习

通读一遍dataset.py之后对Dataset类的作用和基本用法有了大概的了解,接下来对源码进行学习并做总结。

2.1 基本用法

这部分看一下pytorch文档就非常清楚了:https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/

Dataset参数:

  • data_tensor (Tensor) - 包含样本数据
  • target_tensor (Tensor) - 包含样本目标(标签)

此外还可以基于Dataset类自定义自己的数据集,具体的方法有博文描述得非常清楚了,这里不再赘述:https://www.geek-share.com/detail/2702283863.html

2.2 代码学习

Dataset类得定义代码非常简短:

[code]class Dataset(object):

def __getitem__(self, index):
raise NotImplementedError

def __add__(self, other):
return ConcatDataset([self, other])

首先它是一个类,关于定义的时候为什么需要继承“object”我还专门去了解了一下,据说是因为python2版本的遗留问题:当一个类B继承了母类A,并且又派生出了子类C,并且类A和B分别定义了同一个函数func()时,继承了“全部家当”的C类在调用函数func()时会产生歧义:时调用A.func()还是B.func()呢?经典的类(没有继承object)会采用深度优先的搜索策略去调用A.func();而新式类(object)会采用广度优先的搜索策略调用B.func()。基本就是这么一个差异,更详细的分析可以参考:https://www.zhihu.com/question/19754936

其次关于成员函数__getitem__()的理解,这个函数主要是根据索引来返回索引所指向的数据样本的。目前函数体只有raise 语句,说明这个函数不能直接使用,在构造自己的数据集时需要根据自己的数据类型自定义__getitem__()函数,否则程序在执行这一语句时会报错NotImplementedError.

具体的使用我用TensorDataset尝试了一下(因为TensorDataset类时Dataset类的子类,并且已经定义好了自己的__getitem__()函数),代码如下:

[code]# 制作伪数据用于实验
x = torch.unsqueeze(torch.linspace(-1, 1, 10), dim=1)
y = x.pow(2)

# 载入torchDataset
torch_dataset = Data.TensorDataset(x, y)

# 展示数据内容
print(torch_dataset)
for i in range(torch_dataset.__len__()):
print(torch_dataset.__getitem__(i))

数据还是之前测试random_split()的数据,输出结果为:

[code]<torch.utils.data.dataset.TensorDataset object at 0x0000024D927F2160>
(tensor([-1.]), tensor([1.]))
(tensor([-0.7778]), tensor([0.6049]))
(tensor([-0.5556]), tensor([0.3086]))
(tensor([-0.3333]), tensor([0.1111]))
(tensor([-0.1111]), tensor([0.0123]))
(tensor([0.1111]), tensor([0.0123]))
(tensor([0.3333]), tensor([0.1111]))
(tensor([0.5556]), tensor([0.3086]))
(tensor([0.7778]), tensor([0.6049]))
(tensor([1.]), tensor([1.]))

第一行是print(torch_dataset)的数据,可见直接输出只会给出数据集的地址;而采用成员函数__getitem__()输出便可得到对应的数据。其实还有一个更简单的实现方法,就是直接用数组的形式索引:

[code]for i in range(torch_dataset.__len__()):
print(torch_dataset[i])

结果是一样的~

此外还有一个有趣的地方是,在实验random_split()函数时用到了subset类的实例subset,如果将输出语句

print(subset.__getitem__(i)[1])改为
print(subset.__getitem__(i))

则会导致输出结果变成:

[code]<torch.utils.data.dataset.Subset object at 0x000002446FF88630>
<torch.utils.data.dataset.Subset object at 0x0000024403158DD8>
<torch.utils.data.dataset.Subset object at 0x0000024403158E10>
<torch.utils.data.dataset.Subset object at 0x0000024403158D68>
<torch.utils.data.dataset.Subset object at 0x000002440530A160>

没错,是一堆地址,说明subset并没有直接存储分好的子数据集,而是保存了母数据集以及子集的索引,我们__getitem__()得到的是"保存索引的数据集",代码表示就是:

[code]    def __getitem__(self, idx):
return self.dataset[self.indices[idx]]

感觉动手实践一下豁然开朗了!

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: