diff --git a/README.md b/README.md
index 1a4a9eb7..29a6c1ae 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
 * PNasNet (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
 * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
   * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
-* My generic MobileNet (GenMobileNet) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable, InvertedResidual, etc blocks
+* Generic MobileNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
   * MNASNet B1, A1 (Squeeze-Excite), and Small
   * MobileNet-V1
   * MobileNet-V2
@@ -49,7 +49,8 @@ Several (less common) features that I often utilize in my projects are included.
 * PyTorch w/ single GPU single process (AMP optional)
 * A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
 * A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
-* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Smoothed Softmax, etc)
+* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
+* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
 * An inference script that dumps output to CSV is provided as an example
 
 ### Custom Weights
diff --git a/loss/__init__.py b/loss/__init__.py
index 9ad83fb1..6eaa4c76 100644
--- a/loss/__init__.py
+++ b/loss/__init__.py
@@ -1 +1 @@
-from loss.cross_entropy import LabelSmoothingCrossEntropy
\ No newline at end of file
+from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
\ No newline at end of file
diff --git a/loss/cross_entropy.py b/loss/cross_entropy.py
index db4aaed9..821b1fe3 100644
--- a/loss/cross_entropy.py
+++ b/loss/cross_entropy.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -24,3 +25,12 @@ class LabelSmoothingCrossEntropy(nn.Module):
         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
         return loss.mean()
 
+
+class SparseLabelCrossEntropy(nn.Module):
+
+    def __init__(self):
+        super(SparseLabelCrossEntropy, self).__init__()
+
+    def forward(self, x, target):
+        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
+        return loss.mean()
diff --git a/train.py b/train.py
index 9369e62a..b9ac1891 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,7 @@ except ImportError:
 from data import *
 from models import create_model, resume_checkpoint
 from utils import *
-from loss import LabelSmoothingCrossEntropy
+from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
 from optim import create_optimizer
 from scheduler import create_scheduler
 
@@ -79,6 +79,10 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                     help='SGD momentum (default: 0.9)')
 parser.add_argument('--weight-decay', type=float, default=0.0001,
                     help='weight decay (default: 0.0001)')
+parser.add_argument('--mixup', type=float, default=0.0,
+                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
+parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
+                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
 parser.add_argument('--bn-tf', action='store_true', default=False,
@@ -246,7 +250,11 @@ def main():
         distributed=args.distributed,
     )
 
-    if args.smoothing:
+    if args.mixup > 0.:
+        # smoothing is handled with mixup label transform
+        train_loss_fn = SparseLabelCrossEntropy().cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    elif args.smoothing:
         train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
         validate_loss_fn = nn.CrossEntropyLoss().cuda()
     else:
@@ -314,6 +322,13 @@ def train_epoch(
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
 
+        if args.mixup > 0.:
+            lam = 1.
+            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
+                lam = np.random.beta(args.mixup, args.mixup)
+            input.mul_(lam).add_(1 - lam, input.flip(0))
+            target = mixup_target(target, args.num_classes, lam, args.smoothing)
+
         output = model(input)
         loss = loss_fn(output, target)
 
diff --git a/utils.py b/utils.py
index b19ed32c..b8d00b06 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,7 @@ import shutil
 import glob
 import csv
 import operator
+import numpy as np
 from collections import OrderedDict
 
 
@@ -139,6 +140,19 @@ def accuracy(output, target, topk=(1,)):
     return res
 
 
+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+    x = x.long().view(-1, 1)
+    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
+    return lam*y1 + (1. - lam)*y2
+
+
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
     if not os.path.exists(outdir):
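
For readers skimming the patch, here is a minimal sketch (not part of the diff) of how the new pieces interact at train time: mixup_target builds label-smoothed soft targets for a batch and its flipped copy, the inputs receive the matching convex blend, and the SparseLabelCrossEntropy reduction is applied against log-softmax outputs. The helpers are restated with a CPU device default instead of 'cuda' so the snippet runs anywhere, and the batch size, num_classes, mixup_alpha, images, and logits below are synthetic placeholders, not values from the patch.

# Standalone illustration of the mixup path added in this diff (synthetic data, CPU only).
import numpy as np
import torch
import torch.nn.functional as F


def one_hot(x, num_classes, on_value=1., off_value=0., device='cpu'):
    # same logic as utils.one_hot in the diff, with a CPU default for portability
    x = x.long().view(-1, 1)
    return torch.full((x.size(0), num_classes), off_value, device=device).scatter_(1, x, on_value)


def mixup_target(target, num_classes, lam=1., smoothing=0.0):
    # label-smoothed one-hot targets for the batch and its flipped copy, blended by lam
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
    return lam * y1 + (1. - lam) * y2


num_classes, mixup_alpha, smoothing = 10, 0.2, 0.1    # illustrative settings only
images = torch.randn(4, 3, 32, 32)                    # synthetic input batch
target = torch.randint(0, num_classes, (4,))

lam = float(np.random.beta(mixup_alpha, mixup_alpha))
images = lam * images + (1. - lam) * images.flip(0)   # out-of-place form of input.mul_(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, num_classes, lam, smoothing)  # dense soft labels, shape (4, 10)

logits = torch.randn(4, num_classes, requires_grad=True)    # stand-in for model(images)
# this is exactly the reduction SparseLabelCrossEntropy.forward performs
loss = torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1).mean()
loss.backward()
print('lam=%.3f loss=%.3f' % (lam, loss.item()))

Note that flip(0) pairs each sample with its reverse-order partner in the same batch, so the blend needs no second batch or explicit index permutation, which is why the same flip appears in both the input mixing in train.py and the target mixing in utils.py.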