Mixup implemention in progress

* initial impl w/ label smoothing converging, but needs more testing
pull/1/head
Ross Wightman 5 years ago
parent c3fbdd4655
commit fee607edf6

@ -29,7 +29,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
* PNasNet (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
* My generic MobileNet (GenMobileNet) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable, InvertedResidual, etc blocks
* Generic MobileNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
* MNASNet B1, A1 (Squeeze-Excite), and Small
* MobileNet-V1
* MobileNet-V2
@ -49,7 +49,8 @@ Several (less common) features that I often utilize in my projects are included.
* PyTorch w/ single GPU single process (AMP optional)
* A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Smoothed Softmax, etc)
* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
* An inference script that dumps output to CSV is provided as an example
### Custom Weights

@ -1 +1 @@
from loss.cross_entropy import LabelSmoothingCrossEntropy
from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy

@ -1,3 +1,4 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -24,3 +25,12 @@ class LabelSmoothingCrossEntropy(nn.Module):
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
class SparseLabelCrossEntropy(nn.Module):
def __init__(self):
super(SparseLabelCrossEntropy, self).__init__()
def forward(self, x, target):
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
return loss.mean()

@ -13,7 +13,7 @@ except ImportError:
from data import *
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler
@ -79,6 +79,10 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
parser.add_argument('--bn-tf', action='store_true', default=False,
@ -246,7 +250,11 @@ def main():
distributed=args.distributed,
)
if args.smoothing:
if args.mixup > 0.:
# smoothing is handled with mixup label transform
train_loss_fn = SparseLabelCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
@ -314,6 +322,13 @@ def train_epoch(
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if args.mixup > 0.:
lam = 1.
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
lam = np.random.beta(args.mixup, args.mixup)
input.mul_(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, args.num_classes, lam, args.smoothing)
output = model(input)
loss = loss_fn(output, target)

@ -5,6 +5,7 @@ import shutil
import glob
import csv
import operator
import numpy as np
from collections import OrderedDict
@ -139,6 +140,19 @@ def accuracy(output, target, topk=(1,)):
return res
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
x = x.long().view(-1, 1)
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
return lam*y1 + (1. - lam)*y2
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
if not os.path.exists(outdir):

Loading…
Cancel
Save