from copy import copy
import random
import warnings
import torch.nn as nn
from . import augmentation_cuda as cudaaugs
class Randomizer(nn.Module):
    """
    Wraps a network (or, in general, any pytorch module) and randomizes
    the input image before the wrapped module's forward pass.

    When the randomizer's forward function is invoked, it first randomizes
    the image in the sample dictionary. It basically works like
    :func:`~crumpets.augmentation.randomize_image`, which is usually
    applied to the image in one of the workers.
    The major difference here is that all augmentations are gpu powered,
    and thus faster.
    Not all augmentation operations are supported: only gamma, noise and
    blur are applied. The randomizer does not rotate or resize.
    The values used for augmenting are picked out of the sample dictionary,
    therefore the sample dictionary must contain them. Usually crumpets
    workers take care of that.

    :param net:
        some network the randomizer shall be wrapped around; if `None`,
        the randomizer returns the (augmented) sample dictionary itself
    """

    def __init__(self, net=None):
        super(Randomizer, self).__init__()
        self.net = net
        # Bookkeeping flag toggled by cuda() / cpu(); not read in this class.
        self.use_cuda = False

    def _forward_net(self, sample):
        """Forward `sample` through the wrapped net, or return it if no net is set."""
        if self.net is None:
            return sample
        # NOTE: `forward` is called directly (as in the original code), which
        # bypasses any hooks registered on `net`.
        return self.net.forward(sample)

    def _add_noise(self, im, augmentation):
        """Apply the noise op matching the channel count (dim 1) of `im`."""
        if im.shape[1] > 1:
            return cudaaugs.add_noise_rgb(im, augmentation)
        return cudaaugs.add_noise_other(im, augmentation)

    def forward(self, sample, *args, **kwargs):
        """
        Applies different randomizing augmentations to input images and then
        forwards result through net, if given.

        If the first augmentation entry contains a falsy
        ``'gpu_augmentation'`` value, a warning is issued and the sample is
        passed on unchanged.

        :param sample:
            dictionary with
            {"image": Tensor of shape n,c,h,w,
            "augmentation": list of augmentation parameters per image in batch}
        :param args:
            accepted for call compatibility; currently ignored
        :param kwargs:
            accepted for call compatibility; currently ignored
        :return:
            modified dictionary with randomized image and network modified
            entries; without a net, the dictionary with the randomized image
        :raises AttributeError:
            if the image does not have exactly 4 dimensions
        """
        augmentation = sample['augmentation']
        first = augmentation[0]
        # The flag of the first entry decides for the whole batch.
        if 'gpu_augmentation' in first and not first['gpu_augmentation']:
            warnings.warn('gpu_augmentation for randomization is not activated, '
                          'but Randomizer Module is used! '
                          'Directly forwarding to net now.')
            # Must also work without a wrapped net: return the sample as is.
            return self._forward_net(sample)

        im = sample['image']
        if len(im.shape) != 4:
            raise AttributeError(
                'image shape length {} != 4, but 4 is required'.format(
                    len(im.shape)
                )
            )

        operations = (
            cudaaugs.add_gamma,
            self._add_noise,
            cudaaugs.add_blur,
        )
        # randomize the order of operations
        order = list(range(len(operations)))
        random.shuffle(order)
        for index in order:
            im = operations[index](im, augmentation)

        # Shallow copy so the caller's dictionary keeps its original image.
        result = copy(sample)
        result['image'] = im
        # forward through net
        return self._forward_net(result)

    def cuda(self, device_id=None):
        """
        Moves the module to the gpu via :meth:`torch.nn.Module.cuda`
        and sets ``use_cuda`` to `True`.

        :param device_id: passed on unchanged to :meth:`torch.nn.Module.cuda`
        :return: self
        """
        super(Randomizer, self).cuda(device_id)
        self.use_cuda = True
        return self

    def cpu(self):
        """
        Moves the module to the cpu via :meth:`torch.nn.Module.cpu`
        and sets ``use_cuda`` to `False`.

        :return: self
        """
        super(Randomizer, self).cpu()
        self.use_cuda = False
        return self