Mixup and prefetcher improvements

* Do mixup in a custom collate fn when the prefetcher is enabled, reducing the performance impact (see the usage sketch below the commit metadata)
* Move mixup code into its own file
* Add an arg to disable the prefetcher
* Fix missing CUDA transfer of input/target when the prefetcher is off
* Fix random erasing in the non-prefetcher path, which hadn't been updated to match the new args
* Default random erasing to off (prob = 0.) for training
pull/2/head
Ross Wightman 5 years ago
parent 780c0a96a4
commit 4d2056722a
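In practical terms, the commit moves mixup into the fast collate path: FastCollateMixup mixes raw uint8 batches on CPU inside the collate fn, and the prefetch loader only has to transfer and normalize. A minimal usage sketch, assuming only the create_loader and FastCollateMixup signatures shown in the hunks below (the dataset path, input size, and hyperparameter values are placeholders; unspecified arguments keep their defaults):

from data import Dataset, create_loader, FastCollateMixup

dataset_train = Dataset('/path/to/train')  # placeholder path

# Build the mixup collate fn only when the prefetcher is in use; otherwise
# mixup stays in the training loop (see the train_epoch hunk below).
mixup_collate = FastCollateMixup(mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000)

loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=128,
    is_training=True,
    use_prefetcher=True,
    collate_fn=mixup_collate,  # new argument added in this commit
)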

@@ -3,3 +3,4 @@ from data.config import resolve_data_config
from data.dataset import Dataset
from data.transforms import *
from data.loader import create_loader
from data.mixup import mixup_target, FastCollateMixup

@@ -1,6 +1,7 @@
import torch.utils.data
from data.transforms import *
from data.distributed_sampler import OrderedDistributedSampler
from data.mixup import FastCollateMixup
def fast_collate(batch):
@@ -60,6 +61,18 @@ class PrefetchLoader:
def sampler(self):
return self.loader.sampler
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x
def create_loader(
dataset,
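The mixup_enabled property pair added to PrefetchLoader above lets the training loop flip mixup off without reaching into the wrapped DataLoader. A minimal sketch of how that is used from the training side, assuming loader_train is the PrefetchLoader returned by create_loader with a FastCollateMixup collate_fn (the cutoff epoch value is illustrative):

def maybe_disable_mixup(loader_train, epoch, mixup_off_epoch=50):
    # PrefetchLoader.mixup_enabled forwards to FastCollateMixup.mixup_enabled,
    # so flipping it here stops mixing for all subsequently collated batches.
    if mixup_off_epoch and epoch >= mixup_off_epoch and loader_train.mixup_enabled:
        loader_train.mixup_enabled = False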
@@ -75,6 +88,7 @@ def create_loader(
num_workers=1,
distributed=False,
crop_pct=None,
collate_fn=None,
):
if isinstance(input_size, tuple):
img_size = input_size[-2:]
@@ -108,13 +122,16 @@ def create_loader(
# of samples per-process, will slightly alter validation results
sampler = OrderedDistributedSampler(dataset)
if collate_fn is None:
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=sampler is None and is_training,
num_workers=num_workers,
sampler=sampler,
collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
collate_fn=collate_fn,
drop_last=is_training,
)
if use_prefetcher:
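For context on why the default collate switches on use_prefetcher: the fast collate path is assumed to stack raw uint8 numpy images and int64 targets (much like FastCollateMixup's __call__ below, minus the mixing), leaving normalization to the prefetch loader on the GPU. A rough sketch of that contract, not the actual fast_collate implementation:

import numpy as np
import torch

def fast_collate_sketch(batch):
    # batch: list of (uint8 numpy image, int label) pairs, as produced when the
    # transforms skip ToTensor/Normalize for the prefetcher path
    targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
    tensor = torch.zeros((len(batch), *batch[0][0].shape), dtype=torch.uint8)
    for i, (img, _) in enumerate(batch):
        tensor[i] += torch.from_numpy(img)
    return tensor, targets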

@@ -0,0 +1,42 @@
import numpy as np
import torch
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
x = x.long().view(-1, 1)
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
return lam*y1 + (1. - lam)*y2
class FastCollateMixup:
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mixup_enabled = True
def __call__(self, batch):
batch_size = len(batch)
lam = 1.
if self.mixup_enabled:
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
mixed = batch[i][0].astype(np.float32) * lam + \
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
np.round(mixed, out=mixed)
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
return tensor, target
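For intuition on what the new module computes, here is a tiny worked example of mixup_target on CPU (assuming the module is importable as data.mixup from the repo root; the numbers are illustrative):

import torch
from data.mixup import mixup_target

targets = torch.tensor([0, 2])  # batch of two labels, 3 classes
soft = mixup_target(targets, num_classes=3, lam=0.7, smoothing=0.0, device='cpu')
# Row 0 mixes one_hot(0) with one_hot(2), its flip partner: [0.7, 0.0, 0.3]
# Row 1 is the mirror image:                                [0.3, 0.0, 0.7]
print(soft)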

@@ -159,7 +159,7 @@ def transforms_imagenet_train(
color_jitter=(0.4, 0.4, 0.4),
interpolation='random',
random_erasing=0.4,
random_erasing_pp=True,
random_erasing_mode='const',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
@@ -183,7 +183,7 @@ def transforms_imagenet_train(
std=torch.tensor(std))
]
if random_erasing > 0.:
tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu'))
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
return transforms.Compose(tfl)
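A small sketch of calling transforms_imagenet_train with the renamed argument; only parameters visible in this hunk are passed, everything else keeps its default, and use_prefetcher=False is assumed to keep the Normalize step shown above inside the composed transform:

from data.transforms import transforms_imagenet_train

# random_erasing is the apply probability; random_erasing_mode replaces the
# old per_pixel flag, with 'const' (constant fill) as the default shown above.
train_tf = transforms_imagenet_train(
    random_erasing=0.4,
    random_erasing_mode='const',
    use_prefetcher=False,  # non-prefetcher path
)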

@@ -10,7 +10,7 @@ try:
except ImportError:
has_apex = False
from data import Dataset, create_loader, resolve_data_config
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
@@ -66,9 +66,9 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.1)')
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
help='Random erase prob (default: 0.4)')
help='Dropout rate (default: 0.)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
help='Random erase mode (default: "const")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
@@ -109,6 +109,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
@@ -119,6 +121,7 @@ parser.add_argument("--local_rank", default=0, type=int)
def main():
args = parser.parse_args()
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
@@ -130,6 +133,7 @@ def main():
args.world_size = 1
r = -1
if args.distributed:
args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl',
@@ -216,12 +220,16 @@ def main():
exit(1)
dataset_train = Dataset(train_dir)
collate_fn = None
if args.prefetcher and args.mixup > 0:
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
use_prefetcher=True,
use_prefetcher=args.prefetcher,
rand_erase_prob=args.reprob,
rand_erase_mode=args.remode,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
@@ -229,6 +237,7 @@ def main():
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
)
eval_dir = os.path.join(args.data, 'validation')
@@ -242,7 +251,7 @@ def main():
input_size=data_config['input_size'],
batch_size=4 * args.batch_size,
is_training=False,
use_prefetcher=True,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
@@ -309,6 +318,10 @@ def train_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', use_amp=False):
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
loader.mixup_enabled = False
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
losses_m = AverageMeter()
@@ -321,13 +334,15 @@
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if args.mixup > 0.:
lam = 1.
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
lam = np.random.beta(args.mixup, args.mixup)
input.mul_(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, args.num_classes, lam, args.smoothing)
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
if args.mixup > 0.:
lam = 1.
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
lam = np.random.beta(args.mixup, args.mixup)
input.mul_(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, args.num_classes, lam, args.smoothing)
output = model(input)
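When the prefetcher is off, mixup stays in the training loop and now runs only after the explicit .cuda() transfer added above. A rough standalone sketch of the same computation on CPU, using the new data.mixup helper (batch shape and hyperparameters are placeholders; the flipped copy is taken before the in-place scale so both terms use the original images):

import numpy as np
import torch
from data.mixup import mixup_target

mixup_alpha, num_classes, smoothing = 0.2, 1000, 0.1
input = torch.rand(8, 3, 224, 224)            # stand-in image batch
target = torch.randint(0, num_classes, (8,))  # stand-in integer labels

lam = np.random.beta(mixup_alpha, mixup_alpha)
x_flip = input.flip(0)                        # reversed copy of the batch
input.mul_(lam).add_(x_flip, alpha=1 - lam)   # x <- lam * x + (1 - lam) * x_flip
target = mixup_target(target, num_classes, lam, smoothing, device='cpu')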

@@ -140,19 +140,6 @@ def accuracy(output, target, topk=(1,)):
return res
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
x = x.long().view(-1, 1)
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
return lam*y1 + (1. - lam)*y2
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
if not os.path.exists(outdir):
