crumpets.torch.trainer module

class crumpets.torch.trainer.Trainer(network, optimizer, loss, metric, train_policy, val_policy, train_iter, val_iter, outdir, val_loss=None, val_metric=None, snapshot_interval=1, quiet=False)[source]

Bases: object

The Trainer can be used to train a given network. It alternates between training for one epoch and validating the resulting network for one epoch. The given loss is evaluated for each batch, gradients are computed, and the optimizer is used to update the weights. The loss is also passed to the policy, which might update the learning rate. Useful information about the training flow is regularly printed to the console, including an estimated time of arrival (ETA). Loss, metric and snapshots per epoch are also logged in outdir for later investigation. outdir is created if either quiet is False or snapshot_interval > 0.

Parameters
  • network – The network that is to be trained. If multiple GPUs are used (i.e. multiple devices are passed to the data loader), the network has to be wrapped in a ParallelApply module.

  • optimizer – some torch optimizer, e.g. SGD or Adam, initialized with the network’s parameters.

  • loss – some loss function, e.g. CEL or MSE. Make sure to use crumpets.torch.loss or implement your own, but do not use torch losses directly, since they cannot handle crumpets-style samples (i.e. dictionaries).

  • metric – some metric to further measure the network’s quality. As with losses, use crumpets.torch.metrics.

  • train_policy – some policy to maintain learning rates and such; in torch these are usually called lr_schedulers. After each iteration it is given the current loss and updates learning rates and potentially other hyperparameters.

  • val_policy – same as train_policy, but updates after validation epoch.

  • train_iter – iterator for receiving training samples, usually this means a TorchTurboDataLoader instance.

  • val_iter – same as train_iter, but for retrieving validation samples.

  • outdir – Output directory for logfiles and snapshots. It is created, including all parent directories, if it does not exist.

  • val_loss – same as loss, but applied during validation. Default is None, which results in using loss again for validation.

  • val_metric – same as metric, but applied during validation. Default is None, which results in using metric again for validation.

  • snapshot_interval – Number of epochs between snapshots. Set to 0 or None to disable snapshots. Default is 1, which means taking a snapshot after every epoch.

  • quiet – If True, trainer will not print to console and will not attempt to create a logfile.
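
The loss parameter above requires a loss that works on dictionary-style samples rather than on bare tensors. The following is a minimal plain-Python sketch of that idea, not the crumpets implementation: the class name DictMSELoss and the key names "output" and "target" are assumptions made for illustration, and plain lists stand in for tensors.

```python
# Hypothetical illustration only: DictMSELoss and the keys "output" and
# "target" are assumed names, not the documented crumpets sample layout.
class DictMSELoss:
    """Mean squared error that reads its inputs from a sample dictionary."""

    def __init__(self, output_key="output", target_key="target"):
        self.output_key = output_key
        self.target_key = target_key

    def __call__(self, sample):
        # Pull predictions and targets out of the dictionary-style sample.
        outputs = sample[self.output_key]
        targets = sample[self.target_key]
        if len(outputs) != len(targets):
            raise ValueError("output and target must have the same length")
        squared = [(o - t) ** 2 for o, t in zip(outputs, targets)]
        return sum(squared) / len(squared)


# Plain lists stand in for tensors to keep the sketch dependency-free.
sample = {"output": [1.0, 2.0, 3.0], "target": [1.0, 2.0, 5.0]}
loss_value = DictMSELoss()(sample)
```

The point of the wrapper is that the caller hands over the whole sample and the loss decides which entries to compare, which is why a loss that expects two positional tensors does not fit directly.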

add_hook(name, fun)[source]

Add a function hook for the given event. The function must accept the current trainer state dictionary as its first positional argument, as well as further keyword arguments depending on the type of hook.

The following events are available during training:

  • ‘train_begin’: run at the beginning of a training epoch

  • ‘train_end’: run after a training epoch has ended

  • ‘train_pre_forward’: run before the forward step; receives kwarg sample

  • ‘train_forward’: run after the forward step; receives kwargs metric, loss, and output

  • ‘train_backward’: run after the backward step; receives kwargs metric, loss, and output

During validation the following hooks are available:

  • ‘val_begin’: run at the beginning of a validation epoch

  • ‘val_end’: run after a validation epoch has ended

  • ‘val_pre_forward’: run before the forward step; receives kwarg sample

  • ‘val_forward’: run after the forward step; receives kwargs metric, loss, and output

Parameters
  • name – The event name. See above for available hook names and when they are executed.

  • fun – A function that is to be invoked when given event occurs. See above for method signature.
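
As an illustration of the hook signature described above, here is a minimal sketch of a hook that records the loss after each forward step. HookRegistry is a hypothetical stand-in written for this example so that it runs without crumpets; it is not the Trainer's actual implementation, and the "epoch" key in the state dictionary is an assumption.

```python
from collections import defaultdict


class HookRegistry:
    """Hypothetical stand-in that mimics add_hook / remove_hook."""

    def __init__(self):
        self.hooks = defaultdict(list)

    def add_hook(self, name, fun):
        self.hooks[name].append(fun)

    def remove_hook(self, name, fun):
        self.hooks[name].remove(fun)

    def fire(self, name, state, **kwargs):
        # State dictionary first, event-specific values as keyword arguments.
        for fun in list(self.hooks[name]):
            fun(state, **kwargs)


loss_history = []


def record_loss(state, loss=None, **kwargs):
    # Accept **kwargs so the unused 'metric' and 'output' are tolerated.
    # The "epoch" key is an assumption about the state dictionary.
    loss_history.append((state.get("epoch"), loss))


registry = HookRegistry()
registry.add_hook("train_forward", record_loss)
registry.fire("train_forward", {"epoch": 0}, metric=0.5, loss=1.25, output=None)

# After removal the hook is no longer invoked.
registry.remove_hook("train_forward", record_loss)
registry.fire("train_forward", {"epoch": 0}, metric=0.6, loss=1.0, output=None)
```

With a real Trainer instance, the same record_loss function would be passed to add_hook with the event name 'train_forward'. Taking **kwargs in the hook keeps it usable across events that supply different keyword arguments.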

print_info(epoch)[source]

Prints and logs the current learning rates as well as the epoch.

Parameters

epoch – the current epoch.

remove_hook(name, fun)[source]

Remove the function hook with the given name.

Parameters
  • name – type of hook to remove

  • fun – hook function object to remove

snapshot(epoch)[source]

Stores a snapshot of the current model (including optimizer state). epoch is used only for naming the output file; the model stored is always the current one.

Parameters

epoch – epoch for naming output file

train(num_epochs, start_epoch=0)[source]

Starts the training, logs loss and metrics to the log file, and prints progress to the console, including an ETA. Also stores snapshots of the current model, by default after each epoch (see snapshot_interval).

Parameters
  • num_epochs – number of epochs to train

  • start_epoch – the first epoch, defaults to 0. Can be set higher for fine-tuning, etc.
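
The alternation of training, validation and snapshots can be sketched as a simple schedule. This is an illustrative simulation rather than the library's code: run_schedule is a hypothetical helper, the exact epoch on which a snapshot falls for intervals greater than 1 is an assumption, and start_epoch handling is deliberately left out because the source does not specify how it interacts with num_epochs.

```python
def run_schedule(num_epochs, snapshot_interval=1):
    """Return the ordered list of (step, epoch) pairs for a training run.

    Hypothetical helper for illustration; no network is actually trained.
    """
    events = []
    for epoch in range(num_epochs):
        events.append(("train", epoch))      # one training epoch
        events.append(("validate", epoch))   # followed by one validation epoch
        # 0 or None disables snapshots. Assumption: a snapshot is taken
        # once every `snapshot_interval` completed epochs.
        if snapshot_interval and (epoch + 1) % snapshot_interval == 0:
            events.append(("snapshot", epoch))
    return events


every_epoch = run_schedule(2)                       # default interval of 1
no_snapshots = run_schedule(2, snapshot_interval=0)
every_second = run_schedule(4, snapshot_interval=2)
```

With the default interval every epoch ends in a snapshot, while an interval of 0 or None yields only train and validate steps.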

train_epoch()[source]

Trains one epoch. Invoked by the train function; usually does not need to be called directly.

Returns

train metric result

validate_epoch(epoch)[source]

Validates one epoch. Invoked by the train function; usually does not need to be called directly.

Returns

val metric result