From cf0c280e1bb0c32e80315e740e864019713d0009 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Feb 2019 20:19:11 -0800 Subject: [PATCH] Cleanup tranforms, add custom schedulers, tweak senet34 model --- dataset.py | 9 +- models/__init__.py | 3 +- models/model_factory.py | 63 -------------- models/senet.py | 2 +- models/transforms.py | 73 ++++++++++++++++ scheduler/__init__.py | 3 + scheduler/cosine_lr.py | 72 ++++++++++++++++ scheduler/plateau_lr.py | 68 +++++++++++++++ scheduler/scheduler.py | 73 ++++++++++++++++ scheduler/step_lr.py | 48 +++++++++++ train.py | 185 +++++++++++++++++++--------------------- 11 files changed, 432 insertions(+), 167 deletions(-) create mode 100644 models/transforms.py create mode 100644 scheduler/__init__.py create mode 100644 scheduler/cosine_lr.py create mode 100644 scheduler/plateau_lr.py create mode 100644 scheduler/scheduler.py create mode 100644 scheduler/step_lr.py diff --git a/dataset.py b/dataset.py index 7191bb26..e269e60f 100644 --- a/dataset.py +++ b/dataset.py @@ -9,6 +9,7 @@ import re import torch from PIL import Image + IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg'] @@ -53,7 +54,7 @@ class Dataset(data.Dataset): def __init__( self, root, - transform=None): + transform): imgs, _, _ = find_images_and_targets(root) if len(imgs) == 0: @@ -66,8 +67,7 @@ class Dataset(data.Dataset): def __getitem__(self, index): path, target = self.imgs[index] img = Image.open(path).convert('RGB') - if self.transform is not None: - img = self.transform(img) + img = self.transform(img) if target is None: target = torch.zeros(1).long() return img, target @@ -75,9 +75,6 @@ class Dataset(data.Dataset): def __len__(self): return len(self.imgs) - def set_transform(self, transform): - self.transform = transform - def filenames(self, indices=[], basename=False): if indices: if basename: diff --git a/models/__init__.py b/models/__init__.py index 53cbf96a..13423620 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1,2 @@ -from .model_factory import create_model, get_transforms_eval, get_transforms_train +from .model_factory import create_model +from .transforms import transforms_imagenet_eval, transforms_imagenet_train diff --git a/models/model_factory.py b/models/model_factory.py index 47a5217d..e43da7bf 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -129,66 +129,3 @@ def load_checkpoint(model, checkpoint_path): else: print("Error: No checkpoint found at %s." % checkpoint_path) - -class LeNormalize(object): - """Normalize to -1..1 in Google Inception style - """ - def __call__(self, tensor): - for t in tensor: - t.sub_(0.5).mul_(2.0) - return tensor - - -DEFAULT_CROP_PCT = 0.875 - - -def get_transforms_train(model_name, img_size=224): - if 'dpn' in model_name: - normalize = transforms.Normalize( - mean=[124 / 255, 117 / 255, 104 / 255], - std=[1 / (.0167 * 255)] * 3) - elif 'inception' in model_name: - normalize = LeNormalize() - else: - normalize = transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - return transforms.Compose([ - transforms.RandomResizedCrop(img_size, scale=(0.3, 1.0)), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter(0.3, 0.3, 0.3), - transforms.ToTensor(), - normalize]) - - -def get_transforms_eval(model_name, img_size=224, crop_pct=None): - crop_pct = crop_pct or DEFAULT_CROP_PCT - if 'dpn' in model_name: - if crop_pct is None: - # Use default 87.5% crop for model's native img_size - # but use 100% crop for larger than native as it - # improves test time results across all models. - if img_size == 224: - scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT)) - else: - scale_size = img_size - else: - scale_size = int(math.floor(img_size / crop_pct)) - normalize = transforms.Normalize( - mean=[124 / 255, 117 / 255, 104 / 255], - std=[1 / (.0167 * 255)] * 3) - elif 'inception' in model_name: - scale_size = int(math.floor(img_size / crop_pct)) - normalize = LeNormalize() - else: - scale_size = int(math.floor(img_size / crop_pct)) - normalize = transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - return transforms.Compose([ - transforms.Resize(scale_size, Image.BICUBIC), - transforms.CenterCrop(img_size), - transforms.ToTensor(), - normalize]) diff --git a/models/senet.py b/models/senet.py index 1d8d9056..dc40c1e8 100644 --- a/models/senet.py +++ b/models/senet.py @@ -441,7 +441,7 @@ def senet154(num_classes=1000, pretrained='imagenet'): def se_resnet18(num_classes=1000, pretrained='imagenet'): - model = SENet(SEResNetBottleneck, [2, 2, 2, 2], groups=1, reduction=16, + model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16, dropout_p=None, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, num_classes=num_classes) diff --git a/models/transforms.py b/models/transforms.py new file mode 100644 index 00000000..cdb84456 --- /dev/null +++ b/models/transforms.py @@ -0,0 +1,73 @@ +import torch +from torchvision import transforms +from PIL import Image +import math + + +DEFAULT_CROP_PCT = 0.875 + +IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255] +IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3 +IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] +IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] + + +class LeNormalize(object): + """Normalize to -1..1 in Google Inception style + """ + def __call__(self, tensor): + for t in tensor: + t.sub_(0.5).mul_(2.0) + return tensor + + +def transforms_imagenet_train(model_name, img_size=224, scale=(0.08, 1.0), color_jitter=(0.3, 0.3, 0.3)): + if 'dpn' in model_name: + normalize = transforms.Normalize( + mean=IMAGENET_DPN_MEAN, + std=IMAGENET_DPN_STD) + elif 'inception' in model_name: + normalize = LeNormalize() + else: + normalize = transforms.Normalize( + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD) + + return transforms.Compose([ + transforms.RandomResizedCrop(img_size, scale=scale), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(*color_jitter), + transforms.ToTensor(), + normalize]) + + +def transforms_imagenet_eval(model_name, img_size=224, crop_pct=None): + crop_pct = crop_pct or DEFAULT_CROP_PCT + if 'dpn' in model_name: + if crop_pct is None: + # Use default 87.5% crop for model's native img_size + # but use 100% crop for larger than native as it + # improves test time results across all models. + if img_size == 224: + scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT)) + else: + scale_size = img_size + else: + scale_size = int(math.floor(img_size / crop_pct)) + normalize = transforms.Normalize( + mean=IMAGENET_DPN_MEAN, + std=IMAGENET_DPN_STD) + elif 'inception' in model_name: + scale_size = int(math.floor(img_size / crop_pct)) + normalize = LeNormalize() + else: + scale_size = int(math.floor(img_size / crop_pct)) + normalize = transforms.Normalize( + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD) + + return transforms.Compose([ + transforms.Resize(scale_size, Image.BICUBIC), + transforms.CenterCrop(img_size), + transforms.ToTensor(), + normalize]) diff --git a/scheduler/__init__.py b/scheduler/__init__.py new file mode 100644 index 00000000..73f9c78d --- /dev/null +++ b/scheduler/__init__.py @@ -0,0 +1,3 @@ +from .cosine_lr import CosineLRScheduler +from .plateau_lr import PlateauLRScheduler +from .step_lr import StepLRScheduler diff --git a/scheduler/cosine_lr.py b/scheduler/cosine_lr.py new file mode 100644 index 00000000..576ec01b --- /dev/null +++ b/scheduler/cosine_lr.py @@ -0,0 +1,72 @@ +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +logger = logging.getLogger(__name__) + + +class CosineLRScheduler(Scheduler): + """ + Cosine annealing with restarts. + This is described in the paper https://arxiv.org/abs/1608.03983. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + t_mul: float = 1., + lr_min: float = 0., + decay_rate: float = 1., + warmup_updates=0, + warmup_lr_init=0, + initialize=True) -> None: + super().__init__(optimizer, param_group_field="lr", initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and t_mul == 1 and decay_rate == 1: + logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.t_mul = t_mul + self.lr_min = lr_min + self.decay_rate = decay_rate + self.warmup_updates = warmup_updates + self.warmup_lr_init = warmup_lr_init + if self.warmup_updates: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values] + else: + self.warmup_steps = [1 for _ in self.base_values] + if self.warmup_lr_init: + super().update_groups(self.warmup_lr_init) + + def get_epoch_values(self, epoch: int): + # this scheduler doesn't update on epoch + return None + + def get_update_values(self, num_updates: int): + if num_updates < self.warmup_updates: + lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps] + else: + curr_updates = num_updates - self.warmup_updates + if self.t_mul != 1: + i = math.floor(math.log(1 - curr_updates / self.t_initial * (1 - self.t_mul), self.t_mul)) + t_i = self.t_mul ** i * self.t_initial + t_curr = curr_updates - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + else: + i = curr_updates // self.t_initial + t_i = self.t_initial + t_curr = curr_updates - (self.t_initial * i) + + gamma = self.decay_rate ** i + lr_min = self.lr_min * gamma + lr_max_values = [v * gamma for v in self.base_values] + + lrs = [ + lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values + ] + return lrs diff --git a/scheduler/plateau_lr.py b/scheduler/plateau_lr.py new file mode 100644 index 00000000..0cad2159 --- /dev/null +++ b/scheduler/plateau_lr.py @@ -0,0 +1,68 @@ +import torch + +from .scheduler import Scheduler + + +class PlateauLRScheduler(Scheduler): + """Decay the LR by a factor every time the validation loss plateaus.""" + + def __init__(self, + optimizer, + factor=0.1, + patience=10, + verbose=False, + threshold=1e-4, + cooldown_epochs=0, + warmup_updates=0, + warmup_lr_init=0, + lr_min=0, + ): + super().__init__(optimizer, 'lr', initialize=False) + + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer.optimizer, + patience=patience, + factor=factor, + verbose=verbose, + threshold=threshold, + cooldown=cooldown_epochs, + min_lr=lr_min + ) + + self.warmup_updates = warmup_updates + self.warmup_lr_init = warmup_lr_init + + if self.warmup_updates: + self.warmup_active = warmup_updates > 0 # this state updates with num_updates + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def state_dict(self): + return { + 'best': self.lr_scheduler.best, + 'last_epoch': self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + self.lr_scheduler.best = state_dict['best'] + if 'last_epoch' in state_dict: + self.lr_scheduler.last_epoch = state_dict['last_epoch'] + + # override the base class step fn completely + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + if val_loss is not None and not self.warmup_active: + self.lr_scheduler.step(val_loss, epoch) + else: + self.lr_scheduler.last_epoch = epoch + + def get_update_values(self, num_updates: int): + if num_updates < self.warmup_updates: + lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps] + else: + self.warmup_active = False # warmup cancelled by first update past warmup_update count + lrs = None # no change on update after warmup stage + return lrs + diff --git a/scheduler/scheduler.py b/scheduler/scheduler.py new file mode 100644 index 00000000..78e8460d --- /dev/null +++ b/scheduler/scheduler.py @@ -0,0 +1,73 @@ +from typing import Dict, Any + +import torch + + +class Scheduler: + """ Parameter Scheduler Base Class + A scheduler base class that can be used to schedule any optimizer parameter groups. + + Unlike the builtin PyTorch schedulers, this is intended to be consistently called + * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value + * At the END of each optimizer update, after incrementing the update count, to calculate next update's value + + The schedulers built on this should try to remain as stateless as possible (for simplicity). + + This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' + and -1 values for special behaviour. All epoch and update counts must be tracked in the training + code and explicitly passed in to the schedulers on the corresponding step or step_update call. + + Based on ideas from: + * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler + * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + param_group_field: str, + initialize: bool = True) -> None: + self.optimizer = optimizer + self.param_group_field = param_group_field + self._initial_param_group_field = f"initial_{param_group_field}" + if initialize: + for i, group in enumerate(self.optimizer.param_groups): + if param_group_field not in group: + raise KeyError(f"{param_group_field} missing from param_groups[{i}]") + group.setdefault(self._initial_param_group_field, group[param_group_field]) + else: + for i, group in enumerate(self.optimizer.param_groups): + if self._initial_param_group_field not in group: + raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") + self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] + self.metric = None # any point to having this for all? + self.update_groups(self.base_values) + + def state_dict(self) -> Dict[str, Any]: + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.__dict__.update(state_dict) + + def get_epoch_values(self, epoch: int): + return None + + def get_update_values(self, num_updates: int): + return None + + def step(self, epoch: int, metric: float = None) -> None: + self.metric = metric + values = self.get_epoch_values(epoch) + if values is not None: + self.update_groups(values) + + def step_update(self, num_updates: int, metric: float = None): + self.metric = metric + values = self.get_update_values(num_updates) + if values is not None: + self.update_groups(values) + + def update_groups(self, values): + if not isinstance(values, (list, tuple)): + values = [values] * len(self.optimizer.param_groups) + for param_group, value in zip(self.optimizer.param_groups, values): + param_group[self.param_group_field] = value diff --git a/scheduler/step_lr.py b/scheduler/step_lr.py new file mode 100644 index 00000000..a00659da --- /dev/null +++ b/scheduler/step_lr.py @@ -0,0 +1,48 @@ +import math +import torch + +from .scheduler import Scheduler + + +class StepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_epochs: int, + decay_rate: float = 1., + warmup_updates=0, + warmup_lr_init=0, + initialize=True) -> None: + super().__init__(optimizer, param_group_field="lr", initialize=initialize) + + self.decay_epochs = decay_epochs + self.decay_rate = decay_rate + self.warmup_updates = warmup_updates + self.warmup_lr_init = warmup_lr_init + + if self.warmup_updates: + self.warmup_active = warmup_updates > 0 # this state updates with num_updates + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def get_epoch_values(self, epoch: int): + if not self.warmup_active: + lrs = [v * (self.decay_rate ** ((epoch + 1) // self.decay_epochs)) + for v in self.base_values] + else: + lrs = None # no epoch updates while warming up + return lrs + + def get_update_values(self, num_updates: int): + if num_updates < self.warmup_updates: + lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps] + else: + self.warmup_active = False # warmup cancelled by first update past warmup_update count + lrs = None # no change on update afte warmup stage + return lrs + + diff --git a/train.py b/train.py index 4639250d..63007370 100644 --- a/train.py +++ b/train.py @@ -6,9 +6,10 @@ from collections import OrderedDict from datetime import datetime from dataset import Dataset -from models import model_factory, get_transforms_eval, get_transforms_train +from models import model_factory, transforms_imagenet_eval, transforms_imagenet_train from utils import * from optim import nadam +import scheduler import torch import torch.nn @@ -48,6 +49,8 @@ parser.add_argument('--decay-epochs', type=int, default=30, metavar='N', help='epoch interval to decay LR') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') +parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "step"') parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', help='Dropout rate (default: 0.1)') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', @@ -93,22 +96,9 @@ def main(): num_epochs = args.epochs torch.manual_seed(args.seed) - model = model_factory.create_model( - args.model, - pretrained=args.pretrained, - num_classes=1000, - drop_rate=args.drop, - global_pool=args.gp, - checkpoint_path=args.initial_checkpoint) - - if args.initial_batch_size: - batch_size = adjust_batch_size( - epoch=0, initial_bs=args.initial_batch_size, target_bs=args.batch_size) - print('Setting batch-size to %d' % batch_size) - dataset_train = Dataset( os.path.join(args.data, 'train'), - transform=get_transforms_train(args.model)) + transform=transforms_imagenet_train(args.model)) loader_train = data.DataLoader( dataset_train, @@ -119,7 +109,7 @@ def main(): dataset_eval = Dataset( os.path.join(args.data, 'validation'), - transform=get_transforms_eval(args.model)) + transform=transforms_imagenet_eval(args.model)) loader_eval = data.DataLoader( dataset_eval, @@ -128,38 +118,17 @@ def main(): num_workers=args.workers ) - train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss() - train_loss_fn = train_loss_fn.cuda() - validate_loss_fn = validate_loss_fn.cuda() - - if args.opt.lower() == 'sgd': - optimizer = optim.SGD( - model.parameters(), lr=args.lr, - momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) - elif args.opt.lower() == 'adam': - optimizer = optim.Adam( - model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'nadam': - optimizer = nadam.Nadam( - model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'adadelta': - optimizer = optim.Adadelta( - model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'rmsprop': - optimizer = optim.RMSprop( - model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps, - momentum=args.momentum, weight_decay=args.weight_decay) - else: - assert False and "Invalid optimizer" - exit(1) - - if not args.decay_epochs: - lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=8) - else: - lr_scheduler = None + model = model_factory.create_model( + args.model, + pretrained=args.pretrained, + num_classes=1000, + drop_rate=args.drop, + global_pool=args.gp, + checkpoint_path=args.initial_checkpoint) # optionally resume from a checkpoint start_epoch = 0 if args.start_epoch is None else args.start_epoch + optimizer_state = None if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) @@ -174,7 +143,7 @@ def main(): new_state_dict[name] = v model.load_state_dict(new_state_dict) if 'optimizer' in checkpoint: - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer_state = checkpoint['optimizer'] print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch else: @@ -183,55 +152,73 @@ def main(): print("=> no checkpoint found at '{}'".format(args.resume)) return False - saver = CheckpointSaver(checkpoint_dir=output_dir) - if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() + train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss() + train_loss_fn = train_loss_fn.cuda() + validate_loss_fn = validate_loss_fn.cuda() + + if args.opt.lower() == 'sgd': + optimizer = optim.SGD( + model.parameters(), lr=args.lr, + momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) + elif args.opt.lower() == 'adam': + optimizer = optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'nadam': + optimizer = nadam.Nadam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'adadelta': + optimizer = optim.Adadelta( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'rmsprop': + optimizer = optim.RMSprop( + model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps, + momentum=args.momentum, weight_decay=args.weight_decay) + else: + assert False and "Invalid optimizer" + exit(1) + + if optimizer_state is not None: + optimizer.load_state_dict(optimizer_state) + + if args.sched == 'cosine': + lr_scheduler = scheduler.CosineLRScheduler( + optimizer, + t_initial=13 * len(loader_train), + t_mul=2.0, + lr_min=0, + decay_rate=0.5, + warmup_lr_init=1e-4, + warmup_updates=len(loader_train) + ) + else: + lr_scheduler = scheduler.StepLRScheduler( + optimizer, + decay_epochs=args.decay_epochs, + decay_rate=args.decay_rate, + ) + + saver = CheckpointSaver(checkpoint_dir=output_dir) best_loss = None try: for epoch in range(start_epoch, num_epochs): - if args.decay_epochs: - adjust_learning_rate( - optimizer, epoch, initial_lr=args.lr, - decay_rate=args.decay_rate, decay_epochs=args.decay_epochs) - - if args.initial_batch_size: - next_batch_size = adjust_batch_size( - epoch, initial_bs=args.initial_batch_size, target_bs=args.batch_size) - if next_batch_size > batch_size: - print("Changing batch size from %d to %d" % (batch_size, next_batch_size)) - batch_size = next_batch_size - loader_train = data.DataLoader( - dataset_train, - batch_size=batch_size, - pin_memory=True, - shuffle=True, - # sampler=sampler, - num_workers=args.workers) train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, - saver=saver, output_dir=output_dir) + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir) - step = epoch * len(loader_train) eval_metrics = validate( - step, model, loader_eval, validate_loss_fn, args, - output_dir=output_dir) + model, loader_eval, validate_loss_fn, args) if lr_scheduler is not None: - lr_scheduler.step(eval_metrics['eval_loss']) + lr_scheduler.step(epoch, eval_metrics['eval_loss']) - rowd = OrderedDict(epoch=epoch) - rowd.update(train_metrics) - rowd.update(eval_metrics) - with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf: - dw = csv.DictWriter(cf, fieldnames=rowd.keys()) - if best_loss is None: # first iteration (epoch == 1 can't be used) - dw.writeheader() - dw.writerow(rowd) + update_summary( + epoch, train_metrics, eval_metrics, output_dir, write_header=best_loss is None) # save proper checkpoint with eval metric best_loss = saver.save_checkpoint({ @@ -252,9 +239,8 @@ def main(): def train_epoch( epoch, model, loader, optimizer, loss_fn, args, - saver=None, output_dir=''): + lr_scheduler=None, saver=None, output_dir=''): - epoch_step = (epoch - 1) * len(loader) batch_time_m = AverageMeter() data_time_m = AverageMeter() losses_m = AverageMeter() @@ -263,9 +249,9 @@ def train_epoch( end = time.time() last_idx = len(loader) - 1 + num_updates = epoch * len(loader) for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx - step = epoch_step + batch_idx data_time_m.update(time.time() - end) input = input.cuda() @@ -283,20 +269,27 @@ def train_epoch( loss.backward() optimizer.step() + num_updates += 1 + batch_time_m.update(time.time() - end) if last_batch or batch_idx % args.log_interval == 0: + lrl = [param_group['lr'] for param_group in optimizer.param_groups] + lr = sum(lrl) / len(lrl) + print('Train: {} [{}/{} ({:.0f}%)] ' 'Loss: {loss.val:.6f} ({loss.avg:.4f}) ' 'Time: {batch_time.val:.3f}s, {rate:.3f}/s ' '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) ' + 'LR: {lr:.4f} ' 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( epoch, - batch_idx * len(input), len(loader.sampler), + batch_idx, len(loader), 100. * batch_idx / last_idx, loss=losses_m, batch_time=batch_time_m, rate=input.size(0) / batch_time_m.val, rate_avg=input.size(0) / batch_time_m.avg, + lr=lr, data_time=data_time_m)) if args.save_images: @@ -319,12 +312,15 @@ def train_epoch( epoch=save_epoch, batch_idx=batch_idx) + if lr_scheduler is not None: + lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) + end = time.time() return OrderedDict([('train_loss', losses_m.avg)]) -def validate(step, model, loader, loss_fn, args, output_dir=''): +def validate(model, loader, loss_fn, args): batch_time_m = AverageMeter() losses_m = AverageMeter() prec1_m = AverageMeter() @@ -345,7 +341,6 @@ def validate(step, model, loader, loss_fn, args, output_dir=''): target = target.cuda() output = model(input) - if isinstance(output, (tuple, list)): output = output[0] @@ -381,17 +376,15 @@ def validate(step, model, loader, loss_fn, args, output_dir=''): return metrics -def adjust_learning_rate(optimizer, epoch, initial_lr, decay_rate=0.1, decay_epochs=30): - """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" - lr = initial_lr * (decay_rate ** (epoch // decay_epochs)) - print('Setting LR to', lr) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - -def adjust_batch_size(epoch, initial_bs, target_bs, decay_epochs=1): - batch_size = min(target_bs, initial_bs * (2 ** (epoch // decay_epochs))) - return batch_size +def update_summary(epoch, train_metrics, eval_metrics, output_dir, write_header=False): + rowd = OrderedDict(epoch=epoch) + rowd.update(train_metrics) + rowd.update(eval_metrics) + with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=rowd.keys()) + if write_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + dw.writerow(rowd) if __name__ == '__main__':