from typing import Union
from typing import Iterable
import time
import traceback
from itertools import cycle
from threading import Thread
from abc import ABCMeta
from abc import abstractmethod
from copy import deepcopy
import multiprocessing as mp
from collections import defaultdict
from queue import Queue
from queue import Empty
from queue import Full
import numpy as np
import zmq
from msgpack import packb
from msgpack import unpackb
from msgpack_numpy import encode
from msgpack_numpy import decode
from . import procname
class ProducerBase(mp.Process, metaclass=ABCMeta):
    """
    Base class for producer processes, the first pipeline stage.

    A producer runs in its own process, obtains multipart messages from
    :meth:`yield_requests` and pushes them to the worker sockets in
    round-robin order. Subclasses override :meth:`yield_requests`.

    :param work_addresses:
        List of worker addresses the producer pushes work to;
        cycled through for load balancing
    :param daemon:
        Flag whether this Producer is a daemon process;
        see multiprocessing.Process
    :param queue_length:
        Length of send queue per worker socket
    :param io_threads:
        Number of IO threads to use; 1 is fine for almost all cases
    """

    def __init__(self, work_addresses, daemon=True, queue_length=8, io_threads=1):
        super().__init__(name='pyProducer')
        self.work_addresses = work_addresses
        self.queue_length = queue_length
        self.io_threads = io_threads
        # Shared byte flag so stop() called from another process is seen
        # by the loop in run().
        self.running = mp.Value('b', True)
        self.daemon = daemon

    def stop(self):
        """Clear the shared running flag; run() exits at its next check."""
        self.running.value = False

    # noinspection PyUnresolvedReferences
    def run(self):
        """
        Process entry point: bind one PUSH socket per work address and
        send requests to them in turn until stopped or out of requests.
        """
        procname.setprocname('pyProducer')
        context = zmq.Context(io_threads=self.io_threads)
        sockets = []
        for address in self.work_addresses:
            sock = context.socket(zmq.PUSH)
            sock.setsockopt(zmq.SNDHWM, self.queue_length)
            # The producer binds, so a localhost address is turned into
            # a wildcard bind address.
            sock.bind(address.replace('tcp://localhost', 'tcp://*'))
            sockets.append(sock)
        targets = cycle(sockets)
        requests = self.yield_requests()
        while self.running.value:
            try:
                target = next(targets)
                target.send_multipart(next(requests))
            except (KeyboardInterrupt, StopIteration):
                # Interrupted or no more requests: leave the loop.
                self.stop()

    def yield_requests(self):
        """
        Generator of multipart messages to send; must be overridden.

        :raise NotImplementedError:
            always, in this base implementation
        """
        raise NotImplementedError('implement yield_requests as generator')
class Producer(ProducerBase):
    """
    Producer implementation that reads sequentially from arbitrary
    iterable objects. Items must be msgpack messages that are
    understood by the workers.

    :param work_addresses:
        list of worker addresses, see ProducerBase
    :param iterable:
        iterable of msgpack messages
    :param batch:
        batch size for workers
    :param queue_length:
        length of send queue per worker socket
    :param io_threads:
        number of IO threads to use
    :param daemon:
        daemon flag forwarded to ProducerBase
    """

    def __init__(self, work_addresses, iterable, batch,
                 queue_length=8, io_threads=1, daemon=True):
        ProducerBase.__init__(self, work_addresses,
                              daemon=daemon,
                              queue_length=queue_length,
                              io_threads=io_threads)
        if not hasattr(iterable, '__next__'):
            iterable = iter(iterable)
        self.iterable = iterable
        self._batch = range(batch)

    def yield_requests(self):
        """
        Yield tuples of up to ``batch`` consecutive items.

        The last tuple may be shorter when the iterable runs out in the
        middle of a batch; the generator ends once no items are left.
        """
        # Local import keeps this block self-contained.
        from itertools import islice
        size = len(self._batch)
        # Calling next() inside a generator expression would turn the
        # iterable's StopIteration into RuntimeError (PEP 479), which
        # run() does not catch. islice stops cleanly and keeps the final
        # partial batch.
        while True:
            request = tuple(islice(self.iterable, size))
            if not request:
                return
            yield request
class ConsumerBase(metaclass=ABCMeta):
    """
    Abstract base class for Consumers, the final pipeline stage.
    Implement the _transform method to define subclass behavior.

    :param result_address:
        address of the PULL socket results are received on
    :param recv_timeout:
        value for the socket's RCVTIMEO option
    :param queue_length:
        value for the socket's RCVHWM option
    :param bind:
        if true the socket binds to the address (with localhost replaced
        by a wildcard), otherwise it connects to it
    :param io_threads:
        number of IO threads of the zmq context
    """

    # noinspection PyUnresolvedReferences
    def __init__(
            self,
            result_address,
            recv_timeout=1000,
            queue_length=3,
            bind=True,
            io_threads=1
    ):
        self.running = True
        self.ctx = zmq.Context(io_threads)
        self.result = self.ctx.socket(zmq.PULL)
        self.result.setsockopt(zmq.LINGER, 0)
        self.result.setsockopt(zmq.RCVHWM, queue_length)
        # The receive timeout lets retrieve_data re-check the running
        # flag instead of blocking indefinitely.
        self.result.setsockopt(zmq.RCVTIMEO, recv_timeout)
        if bind:
            result_address = result_address.replace('tcp://localhost', 'tcp://*')
            self.result.bind(result_address)
        else:
            self.result.connect(result_address)

    def stop(self):
        """Clear the running flag so pending retrieve calls return."""
        self.running = False

    def retrieve(self):
        """
        Receive one message and return its transformed form.

        :return:
            result of _transform, or None if the consumer was stopped
            before a message arrived
        """
        data = self.retrieve_data()
        if data is None:
            # retrieve_data gives None once stopped; there is nothing to
            # transform, so _transform is not called with None.
            return None
        return self._transform(data)

    @abstractmethod
    def _transform(self, data):
        """
        Implement this function to transform raw data received from workers
        to a usable Python object.

        :param data:
            raw bytes from workers
        :return:
            transformed Python object
        """
        pass

    def retrieve_data(self):
        """
        Receive one raw message, retrying on receive timeouts.

        :return:
            the received message, or None when the consumer is stopped
        """
        while self.running:
            try:
                return self.result.recv(copy=False)
            except zmq.Again:
                # Timed out; loop to re-check the running flag.
                pass
class ThreadedConsumerBase(ConsumerBase, metaclass=ABCMeta):
    """
    Abstract base class for Consumers that receive and transform data
    on a separate thread.
    Implement the _transform method to define subclass behavior.

    Constructor parameters are passed on to ConsumerBase; queue_length
    additionally bounds the internal queue of transformed results.
    """

    def __init__(
            self,
            result_address,
            recv_timeout=2000,
            queue_length=3,
            bind=True,
            io_threads=1
    ):
        ConsumerBase.__init__(
            self,
            result_address,
            recv_timeout,
            queue_length,
            bind,
            io_threads
        )
        self.queue = Queue(maxsize=queue_length)
        # The retriever thread starts immediately and is a daemon thread.
        self.retriever = Thread(target=self.__retriever_target)
        self.retriever.daemon = True
        self.retriever.start()

    def stop(self):
        """
        Stop the consumer and place a None sentinel in the queue.

        One queued item is discarded per attempt to make room for the
        sentinel; this repeats until the sentinel has been enqueued.
        """
        ConsumerBase.stop(self)
        while True:
            try:
                self.queue.get(False)
            except Empty:
                pass
            try:
                self.queue.put(None, timeout=1)
                break
            except Full:
                pass

    def retrieve(self):
        """
        Return the next transformed result from the queue.

        :return:
            the next queued object; falsy objects are returned as None,
            and None is returned when the consumer is stopped
        """
        while self.running:
            try:
                return self.queue.get(timeout=0.01) or None
            except Empty:
                pass

    def __retriever_target(self):
        """Thread body: receive, transform and enqueue until stopped."""
        while self.running:
            data = self.retrieve_data()
            if data is None:
                # retrieve_data only returns None after stop(); do not
                # hand None to _transform.
                break
            self.queue.put(self._transform(data))
        try:
            # Final sentinel. Bounded wait so the thread can finish even
            # if nobody drains a full queue any more.
            self.queue.put(None, timeout=1)
        except Full:
            pass
class Consumer(ThreadedConsumerBase):
    """
    Basic threaded Consumer that receives and unpacks msgpack messages.
    """

    def _transform(self, data):
        """Decode a raw worker message with the module-level unpack()."""
        message = unpack(data)
        return message
class Proxy(mp.Process):
    """
    Helper process that forwards a zmq PULL stream to a PUSH stream.

    :param in_address:
        address the PULL socket binds to
    :param out_address:
        address the PUSH socket binds to
    :param queue_length:
        SNDHWM of the PUSH socket
    :param daemon:
        daemon flag, see multiprocessing.Process
    """

    def __init__(
            self,
            in_address,
            out_address,
            queue_length=1,
            daemon=True,
    ):
        super().__init__(name='pyProxy')
        # Both sockets bind, so localhost is rewritten to a wildcard.
        wildcard = ('tcp://localhost', 'tcp://*')
        self.in_address = in_address.replace(*wildcard)
        self.out_address = out_address.replace(*wildcard)
        self.queue_length = queue_length
        self.daemon = daemon

    # noinspection PyUnresolvedReferences
    def run(self):
        """
        Process entry point: bind both sockets and run zmq.proxy between
        them; sockets and context are released when the proxy ends.
        """
        procname.setprocname('pyProxy')
        context = zmq.Context(1)
        source = context.socket(zmq.PULL)
        source.setsockopt(zmq.RCVHWM, 1)
        source.bind(self.in_address)
        sink = context.socket(zmq.PUSH)
        sink.setsockopt(zmq.SNDHWM, self.queue_length)
        sink.bind(self.out_address)
        try:
            zmq.proxy(source, sink)
        except KeyboardInterrupt:
            pass
        finally:
            source.close()
            sink.close()
            context.term()
class Value(object):
    """
    Plain holder with a single ``value`` attribute.

    Accepts and ignores any constructor arguments; ``value`` starts
    out as None.
    """

    def __init__(self, *args, **kwargs):
        self.value = None
class Worker(metaclass=ABCMeta):
    """
    Abstract base class for workers.
    Implement the process method to define the behavior of subclasses.

    .. note::
        set_addresses must be called before starting a worker.
        The :class:`~crumpets.dataloader.TurboDataLoader` does this for you.

    :param timeout:
        zmq socket timeout in milliseconds
    :param daemon:
        set daemon flag - used in process
    :param gpu_augmentation:
        set GPU augmentation flag
    """

    def __init__(self, timeout=1000, daemon=True, gpu_augmentation=False):
        self.work_address = None
        self.result_address = None
        self.control_address = None
        self.timeout = timeout
        self.running = Value()  # replaced with multiprocessing.Value
        self.running.value = True
        self.daemon = daemon
        self.gpu_augmentation = gpu_augmentation

    def set_addresses(self, work, result, control):
        """
        Set all required zmq addresses.
        Required before run can be invoked.

        :param work:
            address where work is received on
        :param result:
            results are pushed to this address
        :param control:
            control message are sent here, e.g.,
            exceptions that occurred while processing
        """
        self.work_address = work
        self.result_address = result
        self.control_address = control

    def set_gpu_augmentation(self, val):
        """
        Sets the gpu_augmentation flag to given value, true disables
        all cpu_augmentations for which a gpu version is available.
        Note that this does not directly activate usage of gpu augmentation, as for that
        a :class:`~crumpets.torch.randomizer` module is used, which usually
        the :class:`~crumpets.dataloader.TurboDataLoader` takes care of.

        :param val:
            boolean flag
        """
        self.gpu_augmentation = val

    def stop(self):
        """
        Stops the worker process.
        """
        self.running.value = False

    # noinspection PyUnresolvedReferences
    def inner(self):
        """
        Main loop: pull requests, process them and push the results.

        Runs until the running flag is cleared or a request whose first
        part equals ``b'\\x00'`` arrives. Sockets and context are
        released when the loop ends, including on errors.

        :raise RuntimeError:
            if any of the three addresses is not set
        """
        addr = self.work_address, self.result_address, self.control_address
        if not all(addr):
            raise RuntimeError(
                'Some addresses not set. '
                'Call set_addresses before start. '
                'work: %r, result: %r, control: %r' % addr
            )
        procname.setprocname('pyWorker')
        ctx = zmq.Context()
        work = ctx.socket(zmq.PULL)
        result = ctx.socket(zmq.PUSH)
        try:
            work.setsockopt(zmq.RCVHWM, 1)
            # Receive timeout so the running flag is re-checked regularly.
            work.setsockopt(zmq.RCVTIMEO, self.timeout)
            work.setsockopt(zmq.LINGER, 0)
            work.connect(self.work_address)
            result.setsockopt(zmq.SNDHWM, 1)
            result.setsockopt(zmq.LINGER, 0)
            result.connect(self.result_address)
            while self.running.value:
                try:
                    request = work.recv_multipart(copy=False)
                    # NOTE(review): with copy=False the parts are not
                    # plain bytes; confirm this comparison with bytes
                    # behaves as intended.
                    if request[0] == b'\x00':
                        self.stop()
                        break
                    for data in self.process(request):
                        if data:
                            result.send(data)
                except zmq.Again:
                    # No request within the timeout; try again.
                    pass
        finally:
            # Zero linger matches the socket options above, so closing
            # does not wait for undelivered messages and term() returns.
            work.close(linger=0)
            result.close(linger=0)
            ctx.term()

    def run(self):
        """
        Starts the worker process.

        A KeyboardInterrupt stops the worker quietly. Any other
        exception is reported via the control address, the worker is
        stopped, and the exception is re-raised.
        """
        try:
            self.inner()
        except KeyboardInterrupt:
            self.stop()
        except Exception as e:
            trace = traceback.format_exception(type(e), e, e.__traceback__)
            try:
                self._send_control_msg(str(e), trace)
            finally:
                # Clear the running flag even if reporting failed.
                self.stop()
            raise

    @abstractmethod
    def process(self, data):
        """
        Implement this method to define worker behavior.
        Can return an iterable to create several batches from one input.
        This method can return an iterable or define a generator
        with the yield keyword.
        For instance: :func:`~crumpets.broker.BufferWorker.process`.

        :param data:
            multipart zmq message from Producer to process
        :return:
            iterable of zmq messages to send to Consumer
        """
        return []

    # noinspection PyUnresolvedReferences
    def _send_control_msg(self, msg, trace=None):
        """
        Publish an error report on the control address.

        :param msg:
            message text
        :param trace:
            optional formatted traceback
        :raise ValueError:
            if no control address is set
        """
        if not self.control_address:
            raise ValueError("no control address specified, "
                             "control msg to send was: {}".format(msg))
        ctx = zmq.Context()
        socket = ctx.socket(zmq.PUB)
        try:
            socket.connect(self.control_address)
            # TODO sync that instead of sleep?
            # Dunno if that's possible in zmq - Joachim
            time.sleep(1)
            socket.send(packb({
                "msg": msg,
                "is_control_msg": True,
                "raise_error": True,
                "traceback": trace
            }, use_bin_type=True))
        finally:
            # A bounded linger gives the queued message time to leave
            # before the context goes away, without blocking forever.
            socket.close(linger=1000)
            ctx.term()
def unpack(obj):
    """
    Decode a msgpack message using the msgpack_numpy object hook.

    :param obj:
        msgpack message
    :return:
        unpacked object
    """
    options = {'object_hook': decode, 'raw': False}
    return unpackb(obj, **options)
def make_fill_value(shape, dtype, fill_value: Union[int, float, Iterable] = 0):
    """
    Create a numpy array for a given fill value.

    The result can be assigned to any array of the given shape and dtype
    with a leading batch dimension; an implicit first dimension of size
    one is added for the batch.

    fill_value can be a scalar or iterable.
    Iterables are padded with unit dimensions until they match the number
    of dimensions of the given shape, e.g.:

    >>> make_fill_value((3, 224, 224), np.uint8, (1, 2, 3)).shape
    (1, 3, 1, 1)

    :param shape:
        array shape
    :param dtype:
        array dtype
    :param fill_value:
        optional fill value(s)
    :return:
        fill value array
    """
    values = np.asarray(fill_value, dtype)
    if values.ndim == 0:
        # Scalar: one unit dimension per axis plus one for the batch.
        target = (1,) * (len(shape) + 1)
    else:
        trailing = len(shape) - values.ndim
        target = (1,) + values.shape + (1,) * trailing
    # noinspection PyArgumentList
    return values.reshape(target)
def make_bufferspec(buf):
    """
    Turn a numpy.ndarray or buffer spec into a buffer specification.

    For an array, the fill value is taken from the array itself: along
    every axis that has size one or non-zero variance the full axis is
    kept, along every other axis only the element at index 0 is used.
    For anything else, ``buf`` is treated as a spec whose first two
    items are shape and dtype, and the fill value is built with
    make_fill_value.

    :param buf:
        np.ndarray or buffer spec
    :return:
        tuple(shape, dtype, fill_value)
    """
    if not isinstance(buf, np.ndarray):
        return buf[0], buf[1], make_fill_value(*buf)
    shape = buf.shape
    index = []
    for axis, size in enumerate(shape):
        if size == 1 or np.any(buf.var(axis) > 0):
            # Unit axis or varying values: keep the axis.
            index.append(slice(None))
        else:
            # Constant along this axis: element 0 represents it.
            index.append(0)
    fill_value = buf[tuple(index)]
    if np.isscalar(fill_value):
        fill_value = fill_value.item()
    return shape, buf.dtype, fill_value
def make_buffer(batchsize, shape, dtype, fill_value):
    """
    Allocate a pre-filled array for a batch size and buffer spec.

    The returned array has shape ``(batchsize,) + shape``.

    :param batchsize:
        size of the first dimension
    :param shape:
        remaining shape of the array
    :param dtype:
        numpy dtype of the array
    :param fill_value:
        array comes pre-filled with this value
    :return:
        array
    """
    full_shape = (batchsize,) + tuple(shape)
    buffer = np.empty(full_shape, dtype)
    buffer[...] = fill_value
    return buffer
class BufferManager(object):
    """
    BufferManager is a compatibility class that replaces the SharedDictManager
    for cases where shared memory is not used by the pipeline.
    It creates buffers from buffer specs for use with the BufferWorker.

    :param batch_size:
        size of the first dimension of every buffer
    :param buffer_specs:
        mapping of buffer name to buffer spec
    """

    def __init__(self, batch_size, buffer_specs):
        self.batch_size = batch_size
        self.buffer_specs = buffer_specs
        self.buffers = None

    def next(self):
        """
        Return the dictionary of buffers as defined by buffer specs.

        Buffers are created on the first call and the same dictionary is
        returned afterwards. Without buffer specs an empty dictionary is
        returned.

        :return:
            buffer dictionary
        """
        if self.buffers is not None:
            return self.buffers
        if not self.buffer_specs:
            return {}
        size = self.batch_size
        created = {}
        for name, spec in self.buffer_specs.items():
            created[name] = make_buffer(size, *spec)
        self.buffers = created
        return self.buffers

    @staticmethod
    def pack(obj):
        """
        Pack an object using msgpack with the msgpack_numpy encoder.

        :param obj:
            object to pack
        :return:
            msgpack message bytes
        """
        return packb(obj, default=encode, use_bin_type=True)

    @staticmethod
    def unpack(data):
        """
        Unpack an msgpack message with the msgpack_numpy decoder.

        :param data:
            msgpack message bytes
        :return:
            unpacked objects
        """
        return unpackb(data, raw=False, object_hook=decode)
class BufferWorker(Worker, metaclass=ABCMeta):
    """
    Base class for workers that use constant-size buffers.

    Buffers are registered with add_buffer and parameter groups with
    add_params; subclasses implement prepare.

    :param buffer_manager:
        object that supplies the output buffers and packs/unpacks
        messages; if None, a `BufferManager` is created on first use
    :param kwargs:
        Passed to broker.Worker.
    """

    def __init__(self, buffer_manager=None, **kwargs):
        Worker.__init__(self, **kwargs)
        self.buffer_manager = buffer_manager
        self.buffer_specs = {}
        self.fill_values = {}
        self.params = {}

    def get_buffer_manager(self):
        """
        Returns the current buffer manager.
        May be None.

        :return: `BufferManager` or `SharedBufferManager` object
        """
        return self.buffer_manager

    def set_buffer_manager(self, buffer_manager):
        """
        Set the buffer manager to be used by this worker.
        Can be None, in which case a `BufferManager` will
        be created as necessary.

        :param buffer_manager:
            a `BufferManager` or `SharedBufferManager` object, or None
        """
        self.buffer_manager = buffer_manager

    def add_buffer(self, key, buf):
        """
        Register a new buffer with the worker.

        :param key:
            name of the buffer
        :param buf:
            buffer spec or array to use as template
        """
        spec = make_bufferspec(buf)
        self.buffer_specs[key] = spec
        self.fill_values[key] = spec[2]

    def add_params(self, key, params, default=None):
        """
        Add a parameter group to the worker.

        :param key:
            name of the parameters
        :param params:
            parameter object, usually dictionary
        :param default:
            default value to use if params is None
        """
        self.params[key] = default if params is None else params

    @abstractmethod
    def prepare(self, sample, batch, buffers):
        """
        Implement this method to define the behavior of the BufferWorker
        subclass. Results must be written to buffers and/or batch object.

        :param sample:
            individual sample object to process
        :param batch:
            the object the sample belongs to;
            append values to lists as necessary
        :param buffers:
            output buffers to use for this sample
        """
        pass

    def process(self, request):
        """
        Prepare every sample of a request and pack the resulting batch.

        Each sample gets its own row of every buffer. A sample whose
        prepare call raises ValueError has its rows reset to the fill
        values. Rows beyond the request length are reset to the fill
        values as well.

        :param request:
            multipart message, one packed sample per part
        :return:
            one-element tuple with the packed batch
        """
        n = len(request)
        if self.buffer_manager is None:
            self.buffer_manager = BufferManager(n, self.buffer_specs)
        buffers = self.buffer_manager.next()
        rows = [{k: buf[i] for k, buf in buffers.items()}
                for i in range(n)]
        batch = defaultdict(list)
        for row, sample in zip(rows, request):
            try:
                self.prepare(self.buffer_manager.unpack(sample), batch, row)
            except ValueError as e:
                for k, buf in row.items():
                    buf[...] = self.fill_values[k]
                print('[BufferWorker] cannot prepare sample:', str(e))
        # Request is smaller than the buffer's batch dimension: reset the
        # remaining rows. This has to compare against the buffer length;
        # rows always has exactly n entries.
        for k, buf in buffers.items():
            if n < len(buf):
                buf[n:] = self.fill_values[k]
        batch.update(buffers)
        return self.buffer_manager.pack(batch),  # comma to return tuple
class Dispatcher(object):
    """
    The Dispatcher creates worker processes from a worker template,
    can start and stop them and monitor their status.

    :param worker_template:
        instance of Worker subclass to use as template for workers;
        copy.deepcopy is used to create as many objects as needed
    :param nworkers:
        number of worker processes to start
    :param work_addresses:
        list of work addresses to use; cycled through
    :param result_addresses:
        list of result addresses to use; cycled through
    :param control_address:
        control address workers can send status updates on
    :param daemon:
        daemon flag for processes, see multiprocessing.Process;
        if None the process default is kept
    :param gpu_augmentation:
        bool passed to workers, true disables cpu augmentations
        where gpu versions are available in :class:`~crumpets.torch.randomizer`;
        if None worker_template.gpu_augmentation is used
    """

    def __init__(
            self,
            worker_template, nworkers,
            work_addresses, result_addresses, control_address,
            daemon=None,
            gpu_augmentation=None
    ):
        work_address = cycle(work_addresses)
        result_address = cycle(result_addresses)
        # Detach the buffer manager while copying so the copies do not
        # each get their own deep copy of it; it is re-attached below.
        try:
            buffer_manager = worker_template.get_buffer_manager()
            worker_template.set_buffer_manager(None)
        except AttributeError:
            buffer_manager = None
        self.workers = [deepcopy(worker_template) for _ in range(nworkers)]
        for worker, work, result \
                in zip(self.workers, work_address, result_address):
            # Shared flag so stop() in this process reaches the child.
            worker.running = mp.Value('b')
            worker.running.value = worker_template.running.value
            worker.set_addresses(work, result, control_address)
            # Only None falls back to the worker's own setting; an
            # explicit False must override it.
            worker.set_gpu_augmentation(
                worker.gpu_augmentation
                if gpu_augmentation is None else gpu_augmentation
            )
            try:
                worker.set_buffer_manager(buffer_manager)
            except AttributeError:
                pass
        try:
            worker_template.set_buffer_manager(buffer_manager)
        except AttributeError:
            pass
        self.processes = [mp.Process(target=w.run) for w in self.workers]
        if daemon is not None:
            for proc in self.processes:
                proc.daemon = daemon

    def start(self):
        """Start all worker processes."""
        for proc in self.processes:
            proc.start()

    def stop(self):
        """Ask all workers to stop by clearing their running flags."""
        for worker in self.workers:
            worker.stop()

    def active(self):
        """
        True if any workers are alive.
        """
        # Liveness is a property of the processes, not of the worker
        # objects they run.
        return any(proc.is_alive() for proc in self.processes)

    def terminate(self):
        """Terminate all worker processes."""
        for proc in self.processes:
            proc.terminate()
class Pipeline(object):
    """
    Bundle of a Dispatcher and a Producer that are started and stopped
    together.

    :param worker_template:
        worker instance handed to the Dispatcher
    :param nworkers:
        number of workers handed to the Dispatcher
    :param iterable:
        iterable handed to the Producer
    :param batch_size:
        batch size handed to the Producer
    :param work_addresses:
        work addresses used by both Dispatcher and Producer
    :param result_addresses:
        result addresses handed to the Dispatcher
    :param producer_kwargs:
        optional extra keyword arguments for the Producer
    :param control_address:
        control address handed to the Dispatcher
    :param gpu_augmentation:
        flag handed to the Dispatcher
    """

    def __init__(
            self,
            worker_template, nworkers, iterable, batch_size,
            work_addresses, result_addresses,
            producer_kwargs=None, control_address=None,
            gpu_augmentation=None
    ):
        self.dispatcher = Dispatcher(
            worker_template,
            nworkers,
            work_addresses,
            result_addresses,
            control_address,
            gpu_augmentation=gpu_augmentation,
        )
        if not producer_kwargs:
            producer_kwargs = {}
        self.producer = Producer(
            work_addresses, iterable, batch_size, **producer_kwargs
        )

    def start(self):
        """Start the workers first, then the producer."""
        self.dispatcher.start()
        self.producer.start()

    def stop(self):
        """
        Stop and terminate the producer, then the dispatcher.

        An AttributeError from either component is ignored, which covers
        components that are already terminated or do not exist.
        """
        for name in ('producer', 'dispatcher'):
            try:
                component = getattr(self, name)
                component.stop()
                component.terminate()
            except AttributeError:
                pass  # component already terminated/does not exist

    # Also run stop when the pipeline object is finalized.
    __del__ = stop