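"""Source code for ``crumpets.torch.utils``: training snapshot save/resume, str/bytes-tolerant state-dict key lookup, and ``nn.Module`` wrappers (``Normalize``, ``Unpacker``)."""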

from collections import OrderedDict

import torch
from torch import nn

from ..presets import IMAGENET_MEAN
from ..presets import IMAGENET_STD


def save(path, iteration, model, optimizer, **kwargs):
    """Write a training snapshot to ``path`` via :func:`torch.save`.

    :param path: destination passed straight to :func:`torch.save`
    :param iteration: stored under the ``'iteration'`` key
    :param model: module whose ``state_dict()`` is stored as ``'model_state'``;
        ``None`` stores ``None``
    :param optimizer: optimizer whose ``state_dict()`` is stored as
        ``'optimizer_state'``; ``None`` stores ``None``
    :param kwargs: extra entries stored next to the ones above; reusing one of
        the three reserved keys raises ``TypeError``
    """
    snapshot = dict(
        iteration=iteration,
        model_state=None if model is None else model.state_dict(),
        optimizer_state=None if optimizer is None else optimizer.state_dict(),
        **kwargs
    )
    torch.save(snapshot, path)
def resume(path, model, optimizer, map_location=None):
    """
    Given parameters, extracts a training state, i.e. initializes a network and optimizer.

    :param path: path to a pytorch snapshot (including model and optimizer states)
    :param model: a network architecture for that the extracted weights are applied to;
        ``None`` skips restoring model weights
    :param optimizer: an optimizer for which the extracted optimizer parameters are applied to;
        ``None`` skips restoring optimizer state
    :param map_location: forwarded to :func:`torch.load`; e.g. ``'cpu'`` lets a
        snapshot saved on a GPU be resumed on a CPU-only machine. ``None`` keeps
        torch's default behaviour.
    :return: the loaded snapshot with ``'model_state'`` and ``'optimizer_state'``
        removed (e.g. ``'iteration'`` and any extra entries)
    """
    snapshot = torch.load(path, map_location=map_location)
    # pop() so the returned dict only holds the remaining bookkeeping entries.
    model_state = snapshot.pop('model_state', None)
    optimizer_state = snapshot.pop('optimizer_state', None)
    if model is not None and model_state is not None:
        model.load_state_dict(model_state)
    if optimizer is not None and optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)
    return snapshot
def other_type(s):
    """Flip ``s`` between ``str`` and ``bytes`` using UTF-8.

    :param s: a ``str`` (encoded to ``bytes``) or ``bytes`` (decoded to ``str``)
    :return: the converted value, or ``None`` for any other type
    """
    if isinstance(s, bytes):
        return s.decode('utf-8')
    if isinstance(s, str):
        return s.encode('utf-8')
    return None
def try_dicts(k, *ds):
    """Return the first non-``None`` value stored under ``k`` in ``ds``.

    Dicts are searched in the given order; an entry whose value is ``None``
    counts as missing.

    :param k: key to look up
    :param ds: mappings supporting ``.get``
    :raises KeyError: ``KeyError(k)`` if no mapping has a non-``None`` value for ``k``
    """
    lookups = (d.get(k) for d in ds)
    found = next((v for v in lookups if v is not None), None)
    if found is None:
        raise KeyError(k)
    return found
def try_types(k, *ds):
    """Like :func:`try_dicts`, but also try ``k`` converted between ``str`` and ``bytes``.

    :param k: key to look up; a ``str`` is retried as UTF-8 ``bytes`` and
        vice versa
    :param ds: mappings searched in order
    :return: the first non-``None`` value found
    :raises KeyError: ``KeyError(k)`` (always the caller's key) if neither form
        is found
    """
    try:
        return try_dicts(k, *ds)
    except KeyError:
        alt = other_type(k)
        if alt is None:
            # k is neither str nor bytes: there is no alternate form, so report
            # the original key instead of searching for (and failing on) None.
            raise
    try:
        return try_dicts(alt, *ds)
    except KeyError:
        # Report the key the caller asked for, not its converted form.
        raise KeyError(k) from None
def filter_state(own_state, state_dict):
    """Build a state dict with exactly the keys (and key order) of ``own_state``.

    For every key of ``own_state`` the value is taken from ``state_dict`` if it
    holds that key either as given or converted between ``str`` and ``bytes``
    (presumably for snapshots whose keys were unpickled as ``bytes`` — confirm
    against callers); otherwise the value already in ``own_state`` is kept.
    Keys present only in ``state_dict`` are dropped.

    :param own_state: reference state, e.g. a module's current ``state_dict()``
    :param state_dict: loaded state whose values take priority
    :return: ``OrderedDict`` keyed like ``own_state``
    """
    def pick(k):
        # Search state_dict alone first: passing own_state in the same lookup
        # would always hit own_state[k] before the str/bytes retry.
        try:
            return try_types(k, state_dict)
        except KeyError:
            return own_state[k]

    return OrderedDict((k, pick(k)) for k in own_state)
class Normalize(nn.Module):
    """Normalize an input batch per channel, then run the wrapped ``module``.

    Computes ``(x.float() - mean) * (1 / std)`` and feeds the result to
    ``module``.

    :param module: network applied to the normalized input; ``None`` returns
        the normalized tensor itself
    :param mean: per-channel values subtracted from the input
    :param std: per-channel values the input is divided by
    :param grad: if ``True`` the normalization is recorded by autograd;
        otherwise it runs under ``torch.no_grad()`` using in-place ops
    """

    def __init__(self, module, mean=IMAGENET_MEAN, std=IMAGENET_STD, grad=False):
        super(Normalize, self).__init__()
        self.module = module
        # view(1, -1, 1, 1): one value per channel on dim 1; presumably the
        # input is an (N, C, H, W) batch — confirm with callers.
        self.register_buffer('mean', torch.tensor(mean).view(1, -1, 1, 1))
        # Store reciprocals so forward multiplies instead of divides.
        invstd = [1.0 / v for v in std]
        self.register_buffer('invstd', torch.tensor(invstd).view(1, -1, 1, 1))
        self.grad = grad

    def forward(self, x):
        """Normalize ``x`` and return ``module(x)`` (or ``x`` if ``module`` is ``None``)."""
        if self.grad:
            x = x.float().sub(self.mean).mul(self.invstd)
        else:
            with torch.no_grad():
                # copy=True: ``x.float()`` returns the caller's tensor itself
                # when it is already float32, and the in-place ops below would
                # then silently normalize the caller's data.
                x = x.to(torch.float32, copy=True)
                x.sub_(self.mean).mul_(self.invstd)
        if self.module is None:
            return x
        # Outside no_grad so the wrapped network still tracks its own gradients.
        return self.module(x)
class Unpacker(nn.Module):
    """Adapt a tensor-in/tensor-out network to dict-style samples.

    Reads ``sample[input_key]``, casts it to float, runs ``module`` on it and
    stores the result in ``sample[output_key]``.

    :param module: network applied to the input tensor
    :param input_key: key of the input tensor in the sample
    :param output_key: key under which the network output is stored
    """

    def __init__(self, module, input_key='image', output_key='output'):
        super(Unpacker, self).__init__()
        self.input_key = input_key
        self.output_key = output_key
        self.module = module

    def forward(self, sample, *_, **__):
        """Run ``module`` on ``sample[input_key]``; extra arguments are ignored.

        :param sample: mutable mapping holding the input tensor
        :return: the same ``sample``, with the float-cast input written back
            under ``input_key`` and the output stored under ``output_key``
        """
        x = sample[self.input_key].float()
        # Keep the sample holding the float input, as the former in-place
        # ``x.data = x.data.float()`` did, without bypassing autograd via .data.
        sample[self.input_key] = x
        # Call the module itself (not .forward) so registered hooks run.
        sample[self.output_key] = self.module(x)
        return sample