
「Deep Learning」Understanding 「torch.utils.data」 in PyTorch

2017-12-09 19:27
Sina Weibo: 小锋子Shawn

Tencent E-mail:403568338@qq.com
http://blog.csdn.net/dgyuanshaofeng/article/details/78761026
4. DataLoader

    Documentation: torch.utils.data.DataLoader

    Source code: torch.utils.data.dataloader

import torch
import torch.multiprocessing as multiprocessing
from .sampler import SequentialSampler, RandomSampler, BatchSampler
import collections
import re
import sys
import traceback
import threading
from torch._six import string_classes

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue

_use_shared_memory = False
"""Whether to use shared memory in default_collate"""

class ExceptionWrapper(object):
    "Wraps an exception plus traceback to communicate across threads"

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))

def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    # Each worker runs single-threaded; parallelism comes from the processes.
    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            # None is the shutdown signal from the main process.
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            # Ship the exception back to the main process instead of dying.
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))

def _pin_memory_loop(in_queue, out_queue, done_event):
    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            # Pass worker exceptions straight through.
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            # Move every CPU tensor in the batch into pinned (page-locked) memory.
            batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))

numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}
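This dtype-to-constructor map serves the scalar branch of default_collate, defined next: a batch of numpy scalars is converted with the constructor matching its dtype name. A minimal sketch of that path, written against the pre-0.4 code listed in this post (newer releases produce the same values through torch.as_tensor instead):

import numpy as np
from torch.utils.data.dataloader import default_collate

# numpy scalars have shape (), so they take the scalar branch:
# elem.dtype.name ('float32') picks torch.FloatTensor out of numpy_type_map.
batch = [np.float32(0.5), np.float32(1.5)]
t = default_collate(batch)
print(t)  # FloatTensor([0.5, 1.5]) with the code shown in this post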

def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))
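As a sanity check on the recursion above, here is default_collate applied to a typical batch of (feature tensor, integer label) samples, the form a Dataset.__getitem__ usually returns. This is a minimal sketch; the tensor sizes are arbitrary:

import torch
from torch.utils.data.dataloader import default_collate

# Three samples, each a (feature tensor, int label) pair.
samples = [(torch.randn(3), 0), (torch.randn(3), 1), (torch.randn(3), 2)]

# Sequence branch: zip(*batch) transposes, then each field is collated.
features, labels = default_collate(samples)
print(features.size())  # torch.Size([3, 3]) -- tensors stacked along new dim 0
print(labels)           # the int labels 0, 1, 2 gathered into a LongTensor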

def pin_memory_batch(batch):
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
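The reason for all this pinning: page-locked (pinned) host memory is the only memory CUDA can copy to the device asynchronously, which lets the transfer of one batch overlap with computation on another. A minimal sketch (non_blocking is the later spelling of this keyword; in the 0.3-era API it was async=True):

import torch

batch = torch.randn(64, 3, 224, 224)
pinned = batch.pin_memory()   # page-locked copy of the batch
print(pinned.is_pinned())     # True

if torch.cuda.is_available():
    # Asynchronous host-to-device copy; only effective from pinned memory.
    gpu_batch = pinned.cuda(non_blocking=True)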

class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.index_queue = multiprocessing.SimpleQueue()
            self.data_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            # Re-raise an exception that occurred in a worker process.
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            # One None per worker tells each worker loop to exit.
            for _ in self.workers:
                self.index_queue.put(None)

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()
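Although DataLoaderIter is internal, its behaviour is easy to observe: iter(loader) constructs one, and each next() yields one collated batch until the batch sampler runs out. A small sketch, using TensorDataset just to have something indexable:

import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(torch.randn(10, 4), torch.arange(0, 10).long())
loader = DataLoader(dataset, batch_size=4)

it = iter(loader)          # builds the iterator over the batch sampler
print(len(it))             # 3 batches: ceil(10 / 4), since drop_last=False
features, labels = next(it)
print(features.size())     # torch.Size([4, 4]) -- first collated batch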

class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
    Interpreting the docstring: a DataLoader "combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset." Note that shuffle defaults to False, so for training you usually turn it on so that the data is reshuffled at every epoch. If you supply your own sampler, shuffle must stay off, since the two are mutually exclusive. num_workers is the number of subprocesses used for loading: the default of 0 loads everything in the main process, while a value such as 2 spawns two worker processes that load batches in parallel.
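Putting it together, a minimal usage sketch (the toy tensors here are placeholders for a real dataset; on platforms that spawn workers, such as Windows, this loop must run under an if __name__ == '__main__': guard):

import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy dataset: 100 samples with 4 features each, plus integer labels.
features = torch.randn(100, 4)
labels = torch.LongTensor(100).random_(0, 2)
dataset = TensorDataset(features, labels)

# Reshuffle every epoch, load in 2 worker processes, and pin host
# memory when batches are headed for the GPU.
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    num_workers=2, pin_memory=torch.cuda.is_available())

for epoch in range(2):
    for batch_features, batch_labels in loader:  # one collated batch per step
        pass  # forward/backward pass goes here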