Source code for crumpets.dataloader

import uuid
import os
import sys
import random
import socket
from itertools import chain
from itertools import islice
from math import ceil
from threading import Thread
from sys import stderr
from queue import Queue
from queue import Empty
from queue import Full

import zmq
import msgpack

from .broker import Pipeline
from .shm import DummyBufferManager
from .shm import SharedBufferManager


def remove_files(files):
    """
    Delete every path in ``files`` on a best-effort basis.

    :param files:
        iterable of file system paths to delete
    """
    for path in files:
        try:
            os.remove(path)
        except OSError:
            # Missing or undeletable files are deliberately ignored.
            continue
def remove_ipc_handles(handles):
    """
    Delete the files behind a collection of ``ipc://`` addresses.

    The ``ipc://`` scheme is stripped from each handle and the remaining
    text is treated as a file path and passed to :func:`remove_files`.

    :param handles:
        iterable of address strings
    """
    paths = []
    for handle in handles:
        paths.append(handle.replace('ipc://', ''))
    remove_files(paths)
def _find_free_port(): while True: try: port = random.randrange(49152, 61000) socket.socket(socket.AF_INET, socket.SOCK_STREAM).bind(('127.0.0.1', port)) return port except OSError: pass
def make_addresses(uid, pipeline, numbers=(('work', 1), ('consume', 1))):
    """
    Build one list of socket addresses per entry in ``numbers``.

    On Windows every address is ``tcp://localhost:<port>`` with a port
    obtained from :func:`_find_free_port`. On all other platforms the
    addresses are ``ipc://`` paths built from the pipeline name, the
    entry's type name, ``uid`` and a running index.

    :param uid:
        identifier embedded in ipc addresses
    :param pipeline:
        pipeline name embedded in ipc addresses
    :param numbers:
        sequence of ``(type_name, count)`` pairs
    :return:
        list with one list of ``count`` address strings per pair,
        in the order given by ``numbers``
    """
    if sys.platform == 'win32':
        groups = []
        for _, count in numbers:
            groups.append(['tcp://localhost:%d' % _find_free_port()
                           for _ in range(count)])
        return groups

    template = 'ipc://%s-%%s-%s-%%d.ipc' % (pipeline, uid)
    groups = []
    for kind, count in numbers:
        groups.append([template % (kind, index) for index in range(count)])
    return groups
def _check_types(vs, t): for v in vs: if not(v is None or isinstance(v, t)): return False return True
class Slicer(object):
    """
    Slice-style chunking for arbitrary iterables and iterators.

    ``Slicer(it)[start:stop:step]`` returns a generator that skips the
    first ``start`` items of the input and then yields chunks of ``step``
    consecutive items (``step`` defaults to 1). If ``stop`` is given, at
    most ``stop - start`` items are produced in total, so the last chunk
    may be shorter than ``step``. Iteration also ends as soon as the
    input is exhausted.

    Every chunk is a lazy ``itertools.chain`` that reads from the one
    shared underlying iterator. A chunk therefore has to be consumed
    completely before the next chunk is requested, otherwise items end
    up in the wrong chunk.

    :param iterable:
        iterable or iterator to read items from
    """

    def __init__(self, iterable):
        # Objects that already look like iterators are used as they are;
        # everything else is wrapped with iter().
        if not (hasattr(iterable, '__next__') or hasattr(iterable, 'next')):
            iterable = iter(iterable)
        self._iterable = iterable

    def __getitem__(self, item):
        """
        Generate chunks according to the slice ``item``.

        This is a generator function, so argument errors are raised
        when the first chunk is requested, not when the slice is taken.

        :param item:
            slice whose start, stop and step are each a non-negative
            int or ``None``
        :raises ValueError:
            if a slice value is neither int nor ``None``, or negative
        """
        bounds = (item.start, item.stop, item.step)
        if not all(v is None or isinstance(v, int) for v in bounds):
            raise ValueError('expected int or None, got: [%r:%r:%r]'
                             % bounds)
        if any(v is not None and v < 0 for v in bounds):
            # Negative values cannot be resolved on a one-shot iterator.
            raise ValueError('expected non-negative values, got: [%r:%r:%r]'
                             % bounds)
        start, stop, step = bounds
        step = step or 1
        it = self._iterable

        # Drop the first `start` items; islice stops quietly if the input
        # runs out early.
        skip = start or 0
        for _ in islice(it, skip):
            pass

        # Number of items still to hand out; None means unbounded.
        remaining = None if stop is None else max(stop - skip, 0)
        while remaining is None or remaining > 0:
            size = step if remaining is None else min(step, remaining)
            # Fetch the first item eagerly so exhaustion ends the generator
            # cleanly instead of yielding an empty chunk. A StopIteration
            # leaking out of a generator becomes RuntimeError (PEP 479).
            try:
                first = next(it)
            except StopIteration:
                return
            yield chain((first,), islice(it, size - 1))
            if remaining is not None:
                remaining -= size
class Consumer(object):
    """
    A Consumer retrieves and forwards processed samples from workers.

    A background thread pulls result messages from ``result_address``,
    unpacks them with the configured buffer manager and hands them over
    through a queue. :meth:`retrieve` returns either a control message
    from ``control_address`` or the next item from that queue.

    :param result_address:
        address to retrieve processed samples from,
        workers send their results to it
    :param control_address:
        address to retrieve control messages from,
        such as exceptions raised in other processes
    :param recv_timeout:
        time to wait in ms until another receiving attempt is made
    :param bind:
        bind addresses instead of connecting to them
    """

    def __init__(
            self,
            result_address,
            control_address,
            recv_timeout=1000,
            bind=True,
    ):
        self.running = True
        # Sockets, queue and thread are created in start().
        self.result = None
        self.control = None
        self.queue = None
        self.retriever = None
        self.result_address = result_address
        self.control_address = control_address
        self.recv_timeout = recv_timeout
        self.bind = bind
        self.queue_length = 1
        self.io_threads = 1
        self.buffer_manager = DummyBufferManager()

    def set_buffer_manager(self, buffer_manager):
        """
        Replace the buffer manager used to unpack received results.

        :param buffer_manager:
            object providing an ``unpack(data)`` method
        """
        self.buffer_manager = buffer_manager

    def _transform(self, data):
        """Unpack a received message; ``None`` is passed through."""
        if data is None:
            return None
        return self.buffer_manager.unpack(data)

    def retrieve_data(self):
        """
        Receive the next raw message from the result socket.

        Receive timeouts are retried for as long as the consumer is
        running.

        :return:
            the received message, or ``None`` once the consumer
            has been stopped
        """
        while self.running:
            try:
                return self.result.recv(copy=False)
            except zmq.Again:
                pass
        return None

    def start(self):
        """
        Starts the sample retriever thread and
        listens on the control stream.
        """
        self.queue = Queue(maxsize=self.queue_length)
        self.retriever = Thread(target=self._retriever_target)
        self.retriever.daemon = True
        self.retriever.start()
        self._connect_control()

    def stop(self):
        """
        Stops all threads opened by this consumer.

        Clears the running flag and, if the consumer was started,
        replaces whatever is queued with a ``None`` sentinel so that a
        blocked :meth:`retrieve` call returns.
        """
        self.running = False
        while True:
            if self.queue is None:
                return
            # Make room for the sentinel by dropping a pending item.
            try:
                self.queue.get(False)
            except Empty:
                pass
            try:
                self.queue.put(None, timeout=1)
                break
            except Full:
                # Another item was queued in the meantime; try again.
                pass

    def retrieve(self):
        """
        Return the next control message or processed item.

        The control socket is polled first without blocking. If nothing
        is pending there, the item queue is waited on for at most
        ``recv_timeout`` milliseconds before the control socket is
        polled again. Control messages such as worker errors are thus
        still noticed while no data arrives.

        :return:
            an unpacked control message, a processed item, or ``None``
            once the consumer has been stopped
        """
        # Queue timeouts are in seconds and must not be negative.
        timeout = None
        if self.recv_timeout is not None and self.recv_timeout >= 0:
            timeout = self.recv_timeout / 1000.0
        while self.running:
            try:
                return msgpack.unpackb(self.control.recv(copy=False),
                                       raw=False)
            except zmq.Again:
                pass
            try:
                item = self.queue.get(timeout=timeout)
            except Empty:
                continue
            # Releases the retriever thread blocked in queue.join().
            self.queue.task_done()
            return item
        return None

    def _open_socket(self, socket_type, address, recv_timeout):
        """
        Create a receiving socket in a fresh context and bind/connect it.

        :param socket_type:
            zmq socket type constant
        :param address:
            address to bind or connect to, depending on ``self.bind``
        :param recv_timeout:
            receive timeout in ms
        :return:
            the configured socket
        """
        ctx = zmq.Context(self.io_threads)
        # noinspection PyUnresolvedReferences
        sock = ctx.socket(socket_type)
        # noinspection PyUnresolvedReferences
        sock.setsockopt(zmq.LINGER, 0)
        # noinspection PyUnresolvedReferences
        sock.setsockopt(zmq.RCVHWM, self.queue_length)
        # noinspection PyUnresolvedReferences
        sock.setsockopt(zmq.RCVTIMEO, recv_timeout)
        if self.bind:
            # Binding uses the wildcard interface instead of localhost.
            sock.bind(address.replace('tcp://localhost', 'tcp://*'))
        else:
            sock.connect(address)
        return sock

    def _connect(self):
        """Open the PULL socket that receives worker results."""
        # noinspection PyUnresolvedReferences
        self.result = self._open_socket(
            zmq.PULL, self.result_address, self.recv_timeout
        )

    def _connect_control(self):
        """Open the SUB socket for control messages and subscribe to all."""
        import time
        # Receive timeout 0 makes control polling non-blocking.
        # noinspection PyUnresolvedReferences
        control = self._open_socket(zmq.SUB, self.control_address, 0)
        # NOTE(review): presumably gives the connection time to settle
        # before subscribing - confirm.
        time.sleep(1)
        # An empty prefix subscribes to every message.
        # noinspection PyUnresolvedReferences
        control.setsockopt(zmq.SUBSCRIBE, b'')
        self.control = control

    def _retriever_target(self):
        """
        Thread body: move results from the socket into the queue.

        After each ``put`` the thread waits in ``queue.join()`` until
        :meth:`retrieve` has called ``task_done``, so the next message is
        only received once the previous item has been taken.
        """
        self._connect()
        try:
            while self.running:
                item = self._transform(self.retrieve_data())
                self.queue.put(item)
                self.queue.join()
        except KeyboardInterrupt:
            pass
        finally:
            self.stop()
class TurboDataLoader(object):
    """
    TurboDataLoader provides fast parallel loading and processing of
    input data.
    Use :class:`~crumpets.torch.dataloader.TorchTurboDataLoader`
    for a version supporting gpu and pytorch tensors.

    Always use the loader inside of a with statement, otherwise
    workers and consumer won't start and stop.

    A `TurboDataLoader` is intended to be used as an iterator.
    Each iteration yields the following data structure:

    .. code-block:: python

        (
            iteration,
            [  # iterable with 1 item per mini-batch
                { ... }  # sample_dict
            ],
        )

    `iteration` counts the batches that the loader has yielded.
    The counter starts at `start_iteration` (default 0) and is
    incremented before a batch is yielded, so the first value is
    `start_iteration + 1`.
    The second element contains as many mini-batches as specified by
    `num_mini_batches`. It is a lazy iterable, not a list: mini-batches
    are fetched from the workers while it is iterated, so consume it
    completely before asking for the next batch.
    Note that the number of samples across all mini-batches is equal
    to `batch_size`, i.e., `batch_size` must be divisible by
    `num_mini_batches`.
    Finally each mini-batch is a dictionary that contains
    key-value-pairs produced by the workers.
    E.g., a :class:`~crumpets.workers.ClassificationWorker` produces
    keys `'image'`, `'label'`, and `'augmentation'`.
    Image and label are arrays and augmentation contains a list
    of one dictionary per sample in the batch with parameters
    used to create said sample.

    Example usage:

    .. code-block:: python

        model = make_some_model()
        with loader:
            for epoch in range(epochs):
                for iteration, mini_batch in loader:
                    for sample in mini_batch:
                        sample = model(sample)
                        images = sample['image']
                        ...

    Depending on parameters, the TurboDataLoader starts several
    processes, some of which cannot be started with the standard
    "fork" method that Python uses in \\*nix systems.
    This can result in crashing with an obscure error message.
    Thus loaders need to be guarded against starting in
    non-main modules, i.e.:

    .. code-block:: python

        if __name__ == "__main__":
            # stuff
            with loader:
                # other stuff

    :param iterable:
        An iterable providing a sample per iteration.
    :param batch_size:
        The amount of samples per batch.
    :param worker_template:
        An actual worker instance, determines the kind of processing.
        Has to inherit crumpets.broker.Worker.
    :param nworkers:
        Number of workers processing the samples simultaneously.
        worker_template is copied to create them.
    :param length:
        Specifies the length of the dataset.
        Defaults to the actual length of iterable (if available).
        If the given value differs from the default, the number of
        iterations per epoch is modified accordingly.
    :param num_mini_batches:
        Number of mini_batches per batch.
    :param start_iteration:
        Start the iteration counter from this number.
        Useful when resuming training.
    :param shared_memory:
        Whether to use shared memory to transfer data from workers.
        If 0 or `False`, shared memory is disabled.
        If `True`, `2*nworkers` shared buffers will be used.
        If any number > 0, that number of buffers will be used.
        A value of 1 is strongly discouraged to prevent deadlocks.
        Permanently storing values returned by a loader may also
        cause deadlocks.
    """

    def __init__(self, iterable, batch_size, worker_template, nworkers,
                 length=None, num_mini_batches=1, start_iteration=0,
                 shared_memory=True):
        uid = uuid.uuid4()
        (self.worker_addresses,
         self.consumer_addresses,
         self.control_addresses) = make_addresses(
            uid, 'torch',
            numbers=(('work', nworkers), ('consume', 1), ('control', 1)),
        )
        # Kept so stop() can delete the ipc files behind all addresses.
        self._addresses = (self.worker_addresses
                           + self.consumer_addresses
                           + self.control_addresses)
        if batch_size % num_mini_batches:
            print("batch_size %d and num_mini_batches %d don't match"
                  % (batch_size, num_mini_batches))
        self.num_mini_batches = num_mini_batches
        self.batch_size = batch_size
        self.mini_batch_size = batch_size // num_mini_batches
        self.nworkers = nworkers
        self.iterations = start_iteration
        # 0 means "unknown": __iter__ then runs until the data ends.
        self.epoch_iterations = 0
        self.length = 0
        if length is None:
            try:
                length = len(iterable)
            except (TypeError, AttributeError):
                # iterable has no usable length
                pass
        if length is not None:
            self.set_length(length)
        self.consumer = Consumer(
            self.consumer_addresses[0],
            self.control_addresses[0],
        )
        if shared_memory:
            nbuffers = nworkers * 2 if shared_memory is True else shared_memory
            manager = SharedBufferManager(
                nbuffers, batch_size, worker_template.buffer_specs
            )
            # Workers and consumer share the same buffer manager.
            worker_template.set_buffer_manager(manager)
            self.consumer.set_buffer_manager(manager)
        self.pipeline = Pipeline(
            worker_template, nworkers, iterable, self.mini_batch_size,
            self.worker_addresses, self.consumer_addresses,
            control_address=self.control_addresses[0],
            gpu_augmentation=False
        )

    def set_length(self, length):
        """
        Set the length of enclosed iterable.
        Modifies epoch_iterations accordingly.

        :param length: len(iterable)
        """
        self.length = length
        self.epoch_iterations = int(ceil(length / self.batch_size))

    def set_epoch_iterations(self, iterations):
        """
        Set number of iterations in one epoch. Does not modify length.

        :param iterations: number of iterations per epoch
        """
        self.epoch_iterations = iterations

    def start(self):
        """
        Start the processing pipeline.
        """
        self.consumer.start()
        self.pipeline.start()

    def __enter__(self):
        self.start()
        return self

    def stop(self):
        """
        Stop the processing pipeline.

        Also runs as ``__del__``, so it tolerates instances whose
        ``__init__`` did not run to completion.
        """
        if hasattr(self, 'pipeline'):
            self.pipeline.stop()
        if hasattr(self, 'consumer'):
            self.consumer.stop()
        addresses = getattr(self, '_addresses', None)
        if addresses:
            remove_ipc_handles(addresses)

    __del__ = stop

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def __consume__(self):
        """
        Generator for the items delivered by the consumer.

        Items flagged as control messages are not yielded. A message
        with ``raise_error`` set stops the loader, prints its traceback
        (if any) to stderr and raises; any other control message is
        printed to stdout.

        :return: generator of processed items
        :raises IOError: for control messages that have raise_error set
        """
        while True:
            item = self.consumer.retrieve()
            if item is None:
                # End of stream. Must be a plain return: raising
                # StopIteration inside a generator is a RuntimeError
                # since Python 3.7 (PEP 479).
                return
            if "is_control_msg" in item and item['is_control_msg']:
                if "raise_error" in item and item["raise_error"]:
                    self.stop()
                    if "traceback" in item:
                        print("\n", *item['traceback'], file=stderr)
                    raise IOError(item["msg"])
                print("[CONTROL MSG] {}".format(item["msg"]))
                # Informational only; it must not be emitted as data.
                continue
            yield item

    def __iter__(self):
        """
        Generator for (iteration, batch) pairs.

        Yields ``epoch_iterations`` batches if that number is positive,
        otherwise batches are yielded until the consumer runs out of
        items.

        :return: generator of (iteration, batch) pairs
        """
        slicer = Slicer(self.__consume__())
        m = self.num_mini_batches
        if self.epoch_iterations > 0:
            batches = slicer[:self.epoch_iterations * m:m]
        else:
            # Unknown length: group mini-batches until the data ends.
            # The Slicer has to be sliced here as well, it cannot be
            # iterated directly.
            batches = slicer[::m]
        for batch in batches:
            self.iterations += 1
            yield self.iterations, batch