from operator import mul
from functools import reduce
from math import ceil
from ctypes import c_uint64
import multiprocessing as mp
from multiprocessing.sharedctypes import RawArray
import weakref
import warnings
import numpy as np
from msgpack import Packer
from msgpack import packb
from msgpack import unpackb
from msgpack import ExtType
from msgpack_numpy import encode as _encode_numpy
from msgpack_numpy import decode as _decode_numpy
# Names exported by ``from <module> import *``.
__all__ = ['shared_array', 'DummyBufferManager', 'SharedBufferManager']
def make_packer(default=None):
    """
    Create a picklable msgpack packer that writes bytes as the ``bin`` type.

    :param default:
        optional msgpack ``default`` hook, called for objects msgpack cannot
        serialize on its own
    :return: PicklablePacker
    """
    options = dict(use_bin_type=True, default=default)
    return PicklablePacker(**options)
class PicklablePacker(Packer):
    """
    msgpack ``Packer`` that supports pickling.

    The constructor arguments are remembered so that unpickling rebuilds an
    equivalent packer (same ``default`` hook, ``use_bin_type``, etc.), e.g.
    when the packer is part of an object sent to another process.
    """

    def __init__(self, *args, **kwargs):
        Packer.__init__(self, *args, **kwargs)
        # Kept for backward compatibility with code reading this attribute.
        self.default_fun = kwargs.get('default')
        # Remember every option, not only ``default``, so that a pickle
        # round-trip does not silently change how objects are serialized.
        self._init_args = args
        self._init_kwargs = kwargs

    def __reduce__(self):
        """Rebuild the packer by calling the class with the original arguments."""
        from functools import partial
        return partial(type(self), *self._init_args, **self._init_kwargs), ()
def shared_array(shape, dtype=np.float32):
    """
    Create a numpy array that resides in shared memory.

    Memory is aligned to 8 bytes and zero-initialized (it is backed by a
    ``multiprocessing.sharedctypes.RawArray`` of 64-bit words).

    :param shape:
        array shape; a sequence of ints or, as in numpy, a single int
    :param dtype:
        numpy dtype
    :return: np.ndarray
    """
    if isinstance(shape, (int, np.integer)):
        # Accept a scalar shape like numpy does; reduce() needs an iterable.
        shape = (int(shape),)
    size = reduce(mul, shape, 1)
    nbytes = size * np.dtype(dtype).itemsize
    # Allocate whole 64-bit words (this is what gives the 8-byte alignment).
    # Round up with integer arithmetic so huge sizes avoid float rounding.
    num_words = int((nbytes + 7) // 8)
    alloc = RawArray(c_uint64, num_words)
    return np.frombuffer(alloc, dtype, size).reshape(shape)
# msgpack extension type code that marks a reference to a shared buffer
# (the character 's', i.e. 115).
EXT_SHARED = ord(b's')
class DummyBufferManager(object):
    """
    Dummy replacement for SharedBufferManager.

    Supports `pack` and `unpack`, but not `next`.
    Nothing is shared: numpy arrays are serialized by value with
    msgpack_numpy and restored by `unpack`.
    """

    def __init__(self):
        # Encode numpy arrays with msgpack_numpy so that `unpack`, which
        # decodes them with the matching object hook, can round-trip them.
        self._packer = make_packer(_encode_numpy)

    def next(self):
        """
        Not supported: there are no shared buffers to hand out.

        :raises NotImplementedError: always
        """
        raise NotImplementedError('DummyBufferManager does not support next')

    def close(self):
        """
        Do nothing. This is provided so that this class can be used wherever
        a SharedBufferManager is closed.
        """

    def pack(self, obj):
        """
        Pack an object using msgpack.

        :param obj: object to pack
        :return: msgpack message bytes
        """
        return self._packer.pack(obj)

    def unpack(self, data):
        """
        Unpack an msgpack message.

        :param data: msgpack message bytes
        :return: packed objects
        """
        return unpackb(data, object_hook=_decode_numpy, raw=False)
class SharedBufferManager(object):
    """
    SharedBufferManager allows transparent sharing of memory between processes.

    On creation, the specified number of shared memory buffer sets is
    allocated according to batch size and buffer specs.
    `next` returns a dict of numpy arrays that point to a set of shared
    memory buffers.
    `next` blocks until a set of buffers becomes available.
    If more than one buffer spec is given, `next` always returns one buffer
    for each spec and only reuses a set of buffers when none of them are
    in use.

    `pack` serializes an arbitrary python object to msgpack format.
    It detects shared buffers and replaces them with a "pointer"
    as extension type `EXT_SHARED`.
    This makes packing fast and independent of array size.
    Other numpy arrays are serialized by value with msgpack_numpy.
    `unpack` detects "pointers" and replaces them with the shared buffer.

    Usage:

    * Sender calls `next` to get a set of available buffers.
    * Sender modifies buffers, calls `pack` and sends message to receiver.
    * Receiver receives the message and calls `unpack`.
    * Receiver uses the unpacked arrays and ensures that they are deleted
      at some point, either by going out of scope or explicitly deleting them.

    Storing shared buffers permanently may cause deadlocks.

    :param num_buffers: number of buffer sets to allocate
    :param batch_size: leading dimension prepended to every buffer shape
    :param buffer_specs: mapping ``name -> (shape, dtype)``
    :param _queueclass: queue type used to hand out free buffer set indices
    """

    def __init__(self, num_buffers, batch_size, buffer_specs, _queueclass=mp.Queue):
        if num_buffers == 1:
            warnings.warn('[SharedBufferManager] num_buffers=1, this may produce deadlocks')
        self.batch_size = batch_size
        self.num_buffers = num_buffers
        # name -> (shape with the batch dimension prepended, dtype)
        self.buffer_specs = {
            k: ((batch_size,) + tuple(spec[0]), spec[1])
            for k, spec in buffer_specs.items()
        }
        # One dict per buffer set: name -> (set index, spec, shared array).
        self._buffer_sets = self._create_buffers()
        # Per buffer set: name -> number of live arrays created by `unpack`.
        self._in_use = [{} for _ in self._buffer_sets]
        # id() of each live array handed out by `next` -> (set index, name).
        self._alias = {}
        # Indices of buffer sets that `next` may hand out.
        self._available = _queueclass()
        for i, _ in enumerate(self._buffer_sets):
            self._available.put(i)
        self._packer = make_packer(self._encode)

    def _create_buffers(self):
        """Allocate ``num_buffers`` sets of shared arrays, one array per spec."""
        return [
            {k: (i, spec, shared_array(spec[0], spec[1]))
             for k, spec in self.buffer_specs.items()}
            for i in range(self.num_buffers)
        ]

    def _signal_done(self, i, k):
        """
        Release one decoded reference to buffer ``k`` of set ``i``.

        When no array decoded from set ``i`` is alive any more, the set is
        put back on the queue of available sets.
        """
        in_use = self._in_use[i]
        in_use[k] -= 1
        if in_use[k] == 0:
            del in_use[k]
            if not in_use:
                self._available.put(i)

    def _encode(self, obj):
        """
        msgpack ``default`` hook used by `pack`.

        Arrays handed out by `next` are replaced by an `EXT_SHARED` reference
        ``(set index, name)``. Any other object is passed to msgpack_numpy.
        """
        try:
            ref = self._alias[id(obj)]
        except KeyError:
            # Not an array from `next`: serialize it by value.
            return _encode_numpy(obj)
        return ExtType(EXT_SHARED, packb(ref))

    @staticmethod
    def _create_alias_decode(obj):
        """Return a new array object viewing the same memory as ``obj``."""
        return obj.reshape(obj.shape)

    def _decode(self, code, data):
        """
        msgpack ``ext_hook`` used by `unpack`.

        An `EXT_SHARED` reference is resolved to a new view of the shared
        buffer. The buffer counts as in use until that view is garbage
        collected. Other extension types are returned as ``ExtType``.
        """
        if code != EXT_SHARED:
            return ExtType(code, data)
        i, k = unpackb(data, raw=False)
        _, _, alloc = self._buffer_sets[i][k]
        # Count references, so a buffer decoded more than once is not
        # released while one of its views is still alive.
        in_use = self._in_use[i]
        in_use[k] = in_use.get(k, 0) + 1
        array = self._create_alias_decode(alloc)
        weakref.finalize(array, self._signal_done, i, k)
        return array

    def close(self):
        """
        Close the queue and unblock any processes waiting on next.

        A ``None`` sentinel is put on the availability queue before closing
        it. `next` returns ``None`` when it receives the sentinel.
        """
        self._available.put(None)
        self._available.close()

    @staticmethod
    def _create_alias_next(obj):
        """Return a new array object viewing the same memory as ``obj``."""
        return obj.reshape(obj.shape)

    def next(self):
        """
        Get a set of free shared buffers, blocking until one is available.

        :return: dict mapping each buffer name to a numpy array viewing shared
            memory, or ``None`` once the manager has been closed
        """
        i = self._available.get()
        if i is None:
            # Re-post the sentinel (via close) so other waiters wake up too.
            # NOTE(review): if this process already closed the queue, putting
            # may raise for multiprocessing queues -- confirm intended usage.
            self.close()
            return None
        buffers = {}
        for k, (_, _, alloc) in self._buffer_sets[i].items():
            view = self._create_alias_next(alloc)
            buffers[k] = view
            # Remember this exact object, so that `pack` replaces it with a
            # reference instead of copying its data...
            self._alias[id(view)] = i, k
            # ...and forget it once the object is collected. ids are reused,
            # so a stale entry could make an unrelated object look shared
            # (and the table would otherwise grow on every call).
            weakref.finalize(view, self._alias.pop, id(view), None)
        return buffers

    def pack(self, obj):
        """
        Pack an object using msgpack.

        Any shared objects are replaced by references.

        :param obj: object to pack
        :return: msgpack message bytes
        """
        return self._packer.pack(obj)

    def unpack(self, data):
        """
        Unpack an msgpack message.

        Any shared object references are replaced with the object.

        :param data: msgpack message bytes
        :return: packed objects
        """
        return unpackb(data, object_hook=_decode_numpy,
                       ext_hook=self._decode, raw=False)