diff --git a/data/__init__.py b/data/__init__.py
index 418d064a..9c1f1f57 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -3,3 +3,4 @@ from data.config import resolve_data_config
 from data.dataset import Dataset
 from data.transforms import *
 from data.loader import create_loader
+from data.mixup import mixup_target, FastCollateMixup
diff --git a/data/loader.py b/data/loader.py
index 2e36c800..21c0313a 100644
--- a/data/loader.py
+++ b/data/loader.py
@@ -1,6 +1,7 @@
 import torch.utils.data
 from data.transforms import *
 from data.distributed_sampler import OrderedDistributedSampler
+from data.mixup import FastCollateMixup
 
 
 def fast_collate(batch):
@@ -60,6 +61,18 @@ class PrefetchLoader:
     def sampler(self):
         return self.loader.sampler
 
+    @property
+    def mixup_enabled(self):
+        if isinstance(self.loader.collate_fn, FastCollateMixup):
+            return self.loader.collate_fn.mixup_enabled
+        else:
+            return False
+
+    @mixup_enabled.setter
+    def mixup_enabled(self, x):
+        if isinstance(self.loader.collate_fn, FastCollateMixup):
+            self.loader.collate_fn.mixup_enabled = x
+
 
 def create_loader(
         dataset,
@@ -75,6 +88,7 @@
         num_workers=1,
         distributed=False,
         crop_pct=None,
+        collate_fn=None,
 ):
     if isinstance(input_size, tuple):
         img_size = input_size[-2:]
@@ -108,13 +122,16 @@
         # of samples per-process, will slightly alter validation results
         sampler = OrderedDistributedSampler(dataset)
 
+    if collate_fn is None:
+        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
+
     loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
         shuffle=sampler is None and is_training,
         num_workers=num_workers,
         sampler=sampler,
-        collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
+        collate_fn=collate_fn,
         drop_last=is_training,
     )
     if use_prefetcher:
diff --git a/data/mixup.py b/data/mixup.py
new file mode 100644
index 00000000..83d51ccb
--- /dev/null
+++ b/data/mixup.py
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+    x = x.long().view(-1, 1)
+    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
+    return lam*y1 + (1. - lam)*y2
+
+
+class FastCollateMixup:
+
+    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+        self.mixup_alpha = mixup_alpha
+        self.label_smoothing = label_smoothing
+        self.num_classes = num_classes
+        self.mixup_enabled = True
+
+    def __call__(self, batch):
+        batch_size = len(batch)
+        lam = 1.
+        if self.mixup_enabled:
+            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+
+        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        for i in range(batch_size):
+            mixed = batch[i][0].astype(np.float32) * lam + \
+                batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+            np.round(mixed, out=mixed)
+            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+
+        return tensor, target
diff --git a/data/transforms.py b/data/transforms.py
index 01141086..e777fbca 100644
--- a/data/transforms.py
+++ b/data/transforms.py
@@ -159,7 +159,7 @@
         color_jitter=(0.4, 0.4, 0.4),
         interpolation='random',
         random_erasing=0.4,
-        random_erasing_pp=True,
+        random_erasing_mode='const',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD
@@ -183,7 +183,7 @@
                 std=torch.tensor(std))
         ]
     if random_erasing > 0.:
-        tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu'))
+        tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
     return transforms.Compose(tfl)
 
 
diff --git a/train.py b/train.py
index 644db18b..63003763 100644
--- a/train.py
+++ b/train.py
@@ -10,7 +10,7 @@ try:
 except ImportError:
     has_apex = False
 
-from data import Dataset, create_loader, resolve_data_config
+from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
 from models import create_model, resume_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
@@ -66,9 +66,9 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
 parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "step"')
 parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
-                    help='Dropout rate (default: 0.1)')
-parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
-                    help='Random erase prob (default: 0.4)')
+                    help='Dropout rate (default: 0.)')
+parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
+                    help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
@@ -109,6 +109,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--no-prefetcher', action='store_true', default=False,
+                    help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
 parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
@@ -119,6 +121,7 @@ parser.add_argument("--local_rank", default=0, type=int)
 
 def main():
     args = parser.parse_args()
+    args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
@@ -130,6 +133,7 @@ def main():
     args.world_size = 1
     r = -1
     if args.distributed:
+        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
@@ -216,12 +220,16 @@ def main():
         exit(1)
     dataset_train = Dataset(train_dir)
 
+    collate_fn = None
+    if args.prefetcher and args.mixup > 0:
+        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
+
     loader_train = create_loader(
         dataset_train,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         is_training=True,
-        use_prefetcher=True,
+        use_prefetcher=args.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
         interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
@@ -229,6 +237,7 @@
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
+        collate_fn=collate_fn,
     )
 
     eval_dir = os.path.join(args.data, 'validation')
@@ -242,7 +251,7 @@
         input_size=data_config['input_size'],
         batch_size=4 * args.batch_size,
         is_training=False,
-        use_prefetcher=True,
+        use_prefetcher=args.prefetcher,
         interpolation=data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
@@ -309,6 +318,10 @@
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir='', use_amp=False):
 
+    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
+        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
+            loader.mixup_enabled = False
+
     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
     losses_m = AverageMeter()
@@ -321,13 +334,15 @@
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
-
-        if args.mixup > 0.:
-            lam = 1.
-            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
-                lam = np.random.beta(args.mixup, args.mixup)
-            input.mul_(lam).add_(1 - lam, input.flip(0))
-            target = mixup_target(target, args.num_classes, lam, args.smoothing)
+        if not args.prefetcher:
+            input = input.cuda()
+            target = target.cuda()
+            if args.mixup > 0.:
+                lam = 1.
+                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
+                    lam = np.random.beta(args.mixup, args.mixup)
+                input.mul_(lam).add_(1 - lam, input.flip(0))
+                target = mixup_target(target, args.num_classes, lam, args.smoothing)
 
         output = model(input)
 
diff --git a/utils.py b/utils.py
index b8d00b06..48936aad 100644
--- a/utils.py
+++ b/utils.py
@@ -140,19 +140,6 @@ def accuracy(output, target, topk=(1,)):
     return res
 
 
-def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
-    x = x.long().view(-1, 1)
-    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
-
-
-def mixup_target(target, num_classes, lam=1., smoothing=0.0):
-    off_value = smoothing / num_classes
-    on_value = 1. - smoothing + off_value
-    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
-    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
-    return lam*y1 + (1. - lam)*y2
-
-
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
     if not os.path.exists(outdir):
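
Usage note (appended, not part of the patch): a minimal sketch of how the pieces added above fit together, mirroring the wiring train.py now does. The dataset path, mixup alpha, batch size, and input size below are placeholder assumptions.

# Sketch only: FastCollateMixup plugged in as the collate_fn for the
# uint8 fast_collate / PrefetchLoader path that create_loader sets up.
from data import Dataset, create_loader, FastCollateMixup

dataset_train = Dataset('/path/to/imagenet/train')  # placeholder path
collate_fn = FastCollateMixup(mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000)

loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=64,
    is_training=True,
    use_prefetcher=True,   # mixup + smoothing are applied in the collate fn on CPU
    collate_fn=collate_fn,
)

# Mixup can later be switched off (e.g. after --mixup-off-epoch) via the new
# property that PrefetchLoader forwards to the collate function:
loader_train.mixup_enabled = False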