import warnings
import torch
from torch.nn import Sequential as Identity
from . import is_cpu_only
from ..dataloader import Consumer
from ..dataloader import TurboDataLoader
from .shm import DummyTensorManager
from .shm import SharedTensorManager
from ..broker import Pipeline
from .randomizer import Randomizer
# Only the loader is exported; TorchConsumer is defined in this module but
# deliberately left out of the star-import surface.
__all__ = ['TorchTurboDataLoader']
class TorchConsumer(Consumer):
    """
    Consumer that receives processed samples from workers and hands them on,
    optionally passing them through a GPU-side randomizer first.

    :param result_address:
        address to retrieve processed samples from, workers send their results to it
    :param control_address:
        address to retrieve control messages from, such as exceptions raised in other processes
    :param recv_timeout:
        time to wait in ms until another receiving attempt is made
    :param bind:
        bind addresses instead of connecting to them
    :param device:
        string or torch device that tensors are copied to
    :param gpu_augmentation:
        uses :class:`~crumpets.torch.randomizer.Randomizer` to gpu augment retrieved samples
    """

    def __init__(self, result_address, control_address, recv_timeout=1000,
                 bind=True, device='cuda:0', gpu_augmentation=False):
        # GPU augmentation on a CPU-only device is allowed, but the caller
        # is told about it since it is most likely not what was intended.
        if gpu_augmentation:
            if is_cpu_only(device):
                warnings.warn('using gpu_augmentation=True with device %r' % device)

        super(TorchConsumer, self).__init__(
            result_address, control_address, recv_timeout, bind,
        )

        # Normalise strings such as 'cuda:0' into a torch.device once and
        # reuse it for both the randomizer and the buffer manager.
        target_device = torch.device(device)
        if gpu_augmentation:
            self.randomizer = Randomizer().to(target_device)
        else:
            # Identity is an empty torch.nn.Sequential (see the module
            # imports), i.e. a pass-through module.
            self.randomizer = Identity()
        self.set_buffer_manager(DummyTensorManager(target_device))

    def _transform(self, data):
        """
        Unpack ``data`` with the parent implementation and run the result
        through ``self.randomizer``. ``None`` is passed through untouched.
        """
        if data is not None:
            batch = super(TorchConsumer, self)._transform(data)
            return self.randomizer(batch)
        return None
class TorchTurboDataLoader(TurboDataLoader):
    """
    TorchTurboDataLoader is a subclass of
    :class:`~crumpets.dataloader.TurboDataLoader`
    intended for use with the Pytorch framework.
    It produces torch tensors instead of numpy arrays.

    See :class:`~crumpets.dataloader.TurboDataLoader`
    for more details on its operation.

    :param iterable:
        An iterable providing a sample per iteration.
    :param batch_size:
        The amount of samples per batch.
    :param worker_template:
        An actual worker instance, determines the kind of processing.
        Has to inherit crumpets.broker.Worker.
        Note that when shared memory is enabled, a buffer manager is
        installed on this instance via ``set_buffer_manager``.
    :param nworkers:
        Number of workers processing the samples simultaneously.
        worker_template is copied to create them.
    :param length:
        Specifies the length of the dataset.
        Defaults to the actual length of iterable (if available).
        If the given value differs from the default,
        the number of iterations per epoch is modified accordingly.
    :param num_mini_batches:
        Number of mini_batches per batch.
    :param start_iteration:
        Start the iteration counter from this number.
        Useful when resuming training.
    :param device:
        torch device to use,
        Defaults to 'cuda:0'.
    :param gpu_augmentation:
        Use a :class:`~crumpets.torch.randomizer.Randomizer`
        to calculate certain data augmentation operations on GPU.
        This disables said operations on the CPU side.
    :param shared_memory:
        Whether to use shared memory to transfer data from workers.
        If 0 or `False`, shared memory is disabled.
        If `True`, `2*nworkers` shared buffers will be used.
        If any number > 0, that number of buffers will be used.
        A value of 1 is strongly discouraged to prevent deadlocks.
        Permanently storing values returned by a loader may also
        cause deadlocks.
    """

    def __init__(self, iterable, batch_size, worker_template, nworkers,
                 length=None, num_mini_batches=1, start_iteration=0,
                 device='cuda:0', gpu_augmentation=False,
                 shared_memory=True):
        super(TorchTurboDataLoader, self).__init__(
            iterable, batch_size, worker_template, nworkers, length, num_mini_batches,
            start_iteration, shared_memory,
        )

        # Replace the consumer with a torch-aware one. The address lists and
        # mini_batch_size used below are expected to be set up by the parent
        # constructor.
        self.consumer = TorchConsumer(
            self.consumer_addresses[0],
            self.control_addresses[0],
            device=device,
            gpu_augmentation=gpu_augmentation
        )

        if shared_memory:
            # `True` means "pick a default" (two buffers per worker); any
            # other truthy value is taken as the explicit buffer count.
            nbuffers = nworkers*2 if shared_memory is True else shared_memory
            manager = SharedTensorManager(
                nbuffers,
                batch_size,
                worker_template.buffer_specs,
                device=device
            )
            # Workers and consumer must share the same manager instance.
            worker_template.set_buffer_manager(manager)
            self.consumer.set_buffer_manager(manager)

        # Built regardless of the shared-memory setting, and after the
        # buffer manager has been attached to the worker template.
        self.pipeline = Pipeline(
            worker_template,
            nworkers,
            iterable,
            self.mini_batch_size,
            self.worker_addresses,
            self.consumer_addresses,
            control_address=self.control_addresses[0],
            gpu_augmentation=gpu_augmentation
        )