Source code for crumpets.workers

import cv2

from ..broker import BufferWorker
from ..augmentation import decode_image
from ..presets import NO_AUGMENTATION
from ..augmentation import randomize_image


__all__ = [
    'ImageWorker',
    'ClassificationWorker',
    'FCNWorker'
]


def noop(im):
    return im


# noinspection PyUnresolvedReferences
def make_cvt(code):
    return lambda im: cv2.cvtColor(im, code)


# noinspection PyUnresolvedReferences
COLOR_CONVERSIONS = {
    None: noop,
    False: noop,
    '': noop,
    'rgb': noop,
    'RGB': noop,
    'hsv': make_cvt(cv2.COLOR_RGB2HSV_FULL),
    'HSV': make_cvt(cv2.COLOR_RGB2HSV_FULL),
    'hls': make_cvt(cv2.COLOR_RGB2HLS_FULL),
    'HLS': make_cvt(cv2.COLOR_RGB2HLS_FULL),
    'lab': make_cvt(cv2.COLOR_RGB2LAB),
    'LAB': make_cvt(cv2.COLOR_RGB2LAB),
    'ycrcb': make_cvt(cv2.COLOR_RGB2YCrCb),
    'YCrCb': make_cvt(cv2.COLOR_RGB2YCrCb),
    'YCRCB': make_cvt(cv2.COLOR_RGB2YCrCb),
    'gray': make_cvt(cv2.COLOR_RGB2GRAY),
    'GRAY': make_cvt(cv2.COLOR_RGB2GRAY),
}


def hwc2chw(im):
    return im.transpose((2, 0, 1))


def chw2hwc(im):
    return im.transpose((1, 2, 0))


def flat(array):
    return tuple(array.flatten().tolist())


[docs]class ImageWorker(BufferWorker): """ Worker for processing images of any kind. :param image: tuple of image information (shape, dtype, fill_value); fill_value is optional, defaults to 0 :param image_params: dict of fixed image parameters; overwrites random augmentation values :param image_rng: RNG object used for image augmentation, see :class:`~crumpets.rng.RNG` and :func:`~crumpets.randomization.randomize_args` :param gpu_augmentation: disables augmentations for which gpu versions are available (:class:`~crumpets.torch.randomizer`) """ def __init__(self, image, image_params=None, image_rng=None, **kwargs): BufferWorker.__init__(self, **kwargs) self.add_buffer('image', image) self.add_params('image', image_params, {}) self.image_rng = image_rng or NO_AUGMENTATION
[docs] def prepare_image(self, im, buffers, params, key): params = dict(params) params.update(self.params[key]) cvt = COLOR_CONVERSIONS[params.pop('colorspace', None)] buffers[key][...] = hwc2chw(cvt(randomize_image( im, buffers[key].shape[1:], background=flat(self.fill_values[key]), **params ))) return params
[docs] def prepare(self, sample, batch, buffers): im = decode_image(sample['image'], self.params['image'].get('color', True)) params = self.image_rng(im, buffers['image']) params['gpu_augmentation'] = self.gpu_augmentation image_params = self.prepare_image(im, buffers, params, 'image') batch['augmentation'].append(image_params) return im, params
[docs]class ClassificationWorker(ImageWorker): """ Worker for processing (Image, Label)-pairs for classification. :param image: tuple of image information (shape, dtype, fill_value); fill_value is optional, defaults to 0 :param label: tuple of label information (shape, dtype, fill_value); fill_value is optional, defaults to 0 :param image_params: dict of fixed image parameters; overwrites random augmentation values :param image_rng: RNG object used for image augmentation, see :class:`~crumpets.rng.RNG` and :func:`~crumpets.randomization.randomize_args` """ def __init__(self, image, label, image_params=None, image_rng=None, **kwargs): ImageWorker.__init__(self, image, image_params, image_rng, **kwargs) self.add_buffer('label', label)
[docs] def prepare(self, sample, batch, buffers): im, params = ImageWorker.prepare(self, sample, batch, buffers) buffers['label'][...] = sample['label'] return im, params
[docs]class FCNWorker(ImageWorker): """ Worker for fully convolutional networks (FCN). Produces `image`-`target_image`-pairs. :param image: tuple of image information (shape, dtype, fill_value); fill_value is optional, defaults to 0 :param target_image: tuple of target image information (shape, dtype, fill_value); fill_value is optional, defaults to 0 :param image_params: dict of fixed image parameters; overwrites random augmentation values :param target_image_params: dict of fixed target image parameters; overwrites random augmentation values :param image_rng: RNG object used for image augmentation, see :class:`~crumpets.rng.RNG` and :func:`~crumpets.randomization.randomize_args` """ def __init__(self, image, target_image, image_params=None, target_image_params=None, image_rng=None, **kwargs): ImageWorker.__init__(self, image, image_params, image_rng, **kwargs) self.add_buffer('target_image', target_image) self.add_params('target_image', target_image_params, {})
[docs] def prepare(self, sample, batch, buffers): im, params = ImageWorker.prepare(self, sample, batch, buffers) self.prepare_image(im, buffers, params, 'target_image')