"""
Example usage of crumpets to train a custom model on Cifar10. Less complex than the resnet example, since
fewer parameters are considered (some are just set to their default value to make the example more intuitive).
Cifar10 can either be processed to be in msgpack format or directly downloaded, using Datadings.
This example is capable of using multiple gpus.
If no datadir is given, a default sample of 10 images is used
while the loader is told that there are 2000 images to mimic a real dataset.
"""
from __future__ import print_function, unicode_literals, division
import os.path as pt
import sys
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from datadings.reader import Cycler
from datadings.reader import MsgpackReader
from six import text_type
from torch.backends import cudnn
from torch.optim import SGD
from crumpets.workers import ClassificationWorker
from crumpets.presets import AUGMENTATION_TRAIN
from crumpets.torch.dataloader import TorchTurboDataLoader
from crumpets.torch.loss import CrossEntropyLoss
from crumpets.torch.metrics import AccuracyMetric
from crumpets.torch.policy import PolyPolicy
from crumpets.torch.trainer import Trainer
# Directory that contains this example script.
ROOT = pt.dirname(__file__)
# Put the parent directory of ROOT at the front of the module search path.
sys.path.insert(0, pt.join(ROOT, '..'))
# Location of the bundled sample dataset; used when no --datadir is given.
DEFAULT_SAMPLE = pt.join(ROOT, '..', 'data', 'cifar10_sample')
# Three mean values scaled by 255 (presumably one per color channel, matching
# 8-bit pixel values); handed to the classification worker in make_loader.
CIFAR10_MEAN = tuple(fraction * 255 for fraction in (0.491, 0.482, 0.447))
class Net(nn.Module):
    """
    Small convolutional classifier with 10 outputs.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling,
    feed three fully connected layers. The flattened feature size of
    ``16 * 5 * 5`` corresponds to 3-channel 32x32 input images.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        # A single pooling module is shared by both convolution stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, sample):
        """
        Run the network on ``sample['image']``.

        :param sample: dict holding the input batch under the key ``'image'``
        :return: the same dict, with the network output stored
            under the key ``'output'``
        """
        # Cast to float first; the convolutions cannot consume integer images.
        features = sample['image'].float()
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        # Flatten everything but the batch dimension for the linear layers.
        features = features.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            features = F.relu(hidden(features))
        sample['output'] = self.fc3(features)
        return sample
def make_loader(
        file,
        batch_size,
        num_mini_batches,
        nworkers,
        image_rng=None,
        image_params=None,
        use_cuda=True,
        gpu_augmentation=True,
):
    """
    Create a TorchTurboDataLoader that cycles over a msgpack dataset.

    :param file: path of the msgpack file to read
    :param batch_size: batch size handed to the loader
    :param num_mini_batches: forwarded to the loader as ``num_mini_batches``
    :param nworkers: number of workers handed to the loader
    :param image_rng: forwarded to the ClassificationWorker as ``image_rng``
    :param image_params: forwarded to the ClassificationWorker
        as ``image_params``
    :param use_cuda: if true the loader device is ``'cuda:0'``,
        otherwise ``'cpu:0'``
    :param gpu_augmentation: forwarded to the loader as ``gpu_augmentation``
    :return: the configured TorchTurboDataLoader
    """
    reader = MsgpackReader(file)
    # The bundled sample is tiny; report 2000 samples for it so the loader
    # mimics a real dataset. Any other file reports its true length.
    nsamples = 2000 if pt.dirname(file) == DEFAULT_SAMPLE else len(reader)
    worker = ClassificationWorker(
        ((3, 32, 32), np.uint8, CIFAR10_MEAN),
        # np.int was only an alias of the builtin int and has been removed
        # from NumPy (1.24), so the builtin is used directly.
        ((1,), int),
        image_params=image_params,
        image_rng=image_rng,
    )
    return TorchTurboDataLoader(
        Cycler(reader).rawiter(), batch_size,
        worker, nworkers,
        gpu_augmentation=gpu_augmentation,
        length=nsamples,
        num_mini_batches=num_mini_batches,
        device='cuda:0' if use_cuda else 'cpu:0',
    )
def make_policy(epochs, network, lr, momentum):
    """
    Create the SGD optimizer and the matching learning rate policy.

    :param epochs: passed to PolyPolicy as its second argument
    :param network: module whose parameters are optimized
    :param lr: learning rate of the single parameter group
    :param momentum: SGD momentum
    :return: tuple ``(optimizer, scheduler)``
    """
    param_groups = [{'params': network.parameters(), 'lr': lr}]
    # Weight decay is fixed at 1e-4 for this example.
    optimizer = SGD(param_groups, momentum=momentum, weight_decay=1e-4)
    # NOTE(review): the trailing 1 is presumably the polynomial
    # exponent of the policy -- confirm against PolyPolicy.
    return optimizer, PolyPolicy(optimizer, epochs, 1)
def main(
        datadir,
        outdir,
        batch_size,
        epochs,
        lr,
        cuda=True
):
    """
    Train :class:`Net` on the msgpack dataset found in ``datadir``.

    :param datadir: directory containing ``train.msgpack`` and
        ``val.msgpack``; a falsy value selects the bundled DEFAULT_SAMPLE
    :param outdir: handed to the Trainer as its output directory
    :param batch_size: batch size for both the training and validation loader
    :param epochs: number of epochs to train
    :param lr: initial learning rate
    :param cuda: if true, network, loss, and loaders are placed on the GPU
    """
    # Previously a missing datadir passed None on to make_loader, which
    # cannot open it; fall back to the sample data instead.
    datadir = datadir or DEFAULT_SAMPLE
    cudnn.benchmark = True
    network = Net()
    if cuda:
        network = network.cuda()
    # Both loaders use 1 mini batch and 4 workers; only the training loader
    # gets the augmentation preset.
    train = make_loader(
        pt.join(datadir, 'train.msgpack'),
        batch_size, 1, 4, use_cuda=cuda,
        gpu_augmentation=cuda, image_rng=AUGMENTATION_TRAIN
    )
    val = make_loader(
        pt.join(datadir, 'val.msgpack'),
        batch_size, 1, 4, use_cuda=cuda,
        gpu_augmentation=cuda
    )
    optimizer, policy = make_policy(epochs, network, lr, 0.9)
    loss = CrossEntropyLoss(target_key='label')
    if cuda:
        loss = loss.cuda()
    trainer = Trainer(
        network, optimizer, loss, AccuracyMetric(),
        policy, None, train, val, outdir
    )
    # Keep both loaders open for the whole training run.
    with train:
        with val:
            # NOTE(review): the 0 is presumably the epoch to
            # start from -- confirm against Trainer.train.
            trainer.train(epochs, 0)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    # Command line options as (flags, add_argument keyword arguments) pairs.
    option_table = (
        (('-b', '--batch-size'), dict(
            default=2,
            type=int,
            help='number of images in batch',
        )),
        (('-e', '--epochs'), dict(
            default=20,
            type=int,
            help='number of epochs to train',
        )),
        (('--lr', '--learning-rate'), dict(
            default=0.0001,
            type=float,
            help='initial learning rate',
        )),
        (('--datadir',), dict(
            type=text_type,
            help='directory containing training and validation data in form of train.msgpack and val.msgpack',
        )),
        (('--outdir',), dict(
            default='.',
            type=text_type,
            help='output directory for snapshots and logs',
        )),
        # CUDA is on by default; passing --cpu switches it off.
        (('--cpu',), dict(
            action="store_false",
            dest="cuda",
            default=True,
            help='activate to use cpu only for augmentations, forwarding and backwardings',
        )),
    )
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    # Unrecognized command line arguments are tolerated and discarded.
    args, _ = parser.parse_known_args()
    datadir = DEFAULT_SAMPLE if args.datadir is None else args.datadir
    try:
        main(
            datadir,
            args.outdir,
            args.batch_size,
            args.epochs,
            args.lr,
            args.cuda,
        )
    except KeyboardInterrupt:
        # Ctrl-C ends the run quietly, without a traceback.
        pass
    finally:
        # Always finish with a newline on stdout.
        print()