diff --git a/sotabench_setup.sh b/sotabench_setup.sh index a3519c26..b3eee0f8 100755 --- a/sotabench_setup.sh +++ b/sotabench_setup.sh @@ -7,7 +7,8 @@ pip install -r requirements-sotabench.txt apt-get update apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev pip uninstall -y pillow -CC="cc -mavx2" pip install -U --force-reinstall pillow-simd +CFLAGS="${CFLAGS} -mavx2" pip install -U --no-cache-dir --force-reinstall --no-binary :all:--compile https://github.com/mrT23/pillow-simd/zipball/simd/7.0.x +#CC="cc -mavx2" pip install -U --force-reinstall pillow-simd # FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work apt-get install wget diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index ef4a0aec..33e4907f 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -1,10 +1,13 @@ -from .nadam import Nadam -from .rmsprop_tf import RMSpropTF +from .adamp import AdamP from .adamw import AdamW -from .radam import RAdam +from .adafactor import Adafactor +from .adahessian import Adahessian +from .lookahead import Lookahead +from .nadam import Nadam from .novograd import NovoGrad from .nvnovograd import NvNovoGrad -from .lookahead import Lookahead -from .adamp import AdamP +from .radam import RAdam +from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .optim_factory import create_optimizer + +from .optim_factory import create_optimizer \ No newline at end of file diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py new file mode 100644 index 00000000..088ce3ac --- /dev/null +++ b/timm/optim/adafactor.py @@ -0,0 +1,174 @@ +""" Adafactor Optimizer + +Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + +Original header/copyright below. + +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import math + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + relative_step (bool): if True, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = lr is None + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_data_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1)) + exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2)) + #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+ + #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update) + #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+ + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr_t) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update) + #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+ + update = exp_avg + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32) + #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+ + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss \ No newline at end of file diff --git a/timm/optim/adahessian.py b/timm/optim/adahessian.py new file mode 100644 index 00000000..985c67ca --- /dev/null +++ b/timm/optim/adahessian.py @@ -0,0 +1,156 @@ +""" AdaHessian Optimizer + +Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py +Originally licensed MIT, Copyright 2020, David Samuel +""" +import torch + + +class Adahessian(torch.optim.Optimizer): + """ + Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning" + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): learning rate (default: 0.1) + betas ((float, float), optional): coefficients used for computing running averages of gradient and the + squared hessian trace (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) + hessian_power (float, optional): exponent of the hessian trace (default: 1.0) + update_each (int, optional): compute the hessian trace approximation only after *this* number of steps + (to save time) (default: 1) + n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1) + """ + + def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, + hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= hessian_power <= 1.0: + raise ValueError(f"Invalid Hessian power value: {hessian_power}") + + self.n_samples = n_samples + self.update_each = update_each + self.avg_conv_kernel = avg_conv_kernel + + # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training + self.seed = 2147483647 + self.generator = torch.Generator().manual_seed(self.seed) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + super(Adahessian, self).__init__(params, defaults) + + for p in self.get_params(): + p.hess = 0.0 + self.state[p]["hessian step"] = 0 + + @property + def is_second_order(self): + return True + + def get_params(self): + """ + Gets all parameters in all param_groups with gradients + """ + + return (p for group in self.param_groups for p in group['params'] if p.requires_grad) + + def zero_hessian(self): + """ + Zeros out the accumalated hessian traces. + """ + + for p in self.get_params(): + if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0: + p.hess.zero_() + + @torch.no_grad() + def set_hessian(self): + """ + Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter. + """ + + params = [] + for p in filter(lambda p: p.grad is not None, self.get_params()): + if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step + params.append(p) + self.state[p]["hessian step"] += 1 + + if len(params) == 0: + return + + if self.generator.device != params[0].device: # hackish way of casting the generator to the right device + self.generator = torch.Generator(params[0].device).manual_seed(self.seed) + + grads = [p.grad for p in params] + + for i in range(self.n_samples): + # Rademacher distribution {-1.0, 1.0} + zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] + h_zs = torch.autograd.grad( + grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) + for h_z, z, p in zip(h_zs, zs, params): + p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. + Arguments: + closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) + """ + + loss = None + if closure is not None: + loss = closure() + + self.zero_hessian() + self.set_hessian() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None or p.hess is None: + continue + + if self.avg_conv_kernel and p.dim() == 4: + p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() + + # Perform correct stepweight decay as in AdamW + p.mul_(1 - group['lr'] * group['weight_decay']) + + state = self.state[p] + + # State initialization + if len(state) == 1: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of Hessian diagonal square values + state['exp_hessian_diag_sq'] = torch.zeros_like(p) + + exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] + beta1, beta2 = group['betas'] + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) + exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + k = group['hessian_power'] + denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) + + # make update + step_size = group['lr'] / bias_correction1 + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 7ae85120..c53be368 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -3,7 +3,18 @@ Hacked together by / Copyright 2020 Ross Wightman """ import torch from torch import optim as optim -from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP, SGDP + +from .adafactor import Adafactor +from .adahessian import Adahessian +from .adamp import AdamP +from .lookahead import Lookahead +from .nadam import Nadam +from .novograd import NovoGrad +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP + try: from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD has_apex = True @@ -29,11 +40,6 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()): def create_optimizer(args, model, filter_bias_and_bn=True): opt_lower = args.opt.lower() weight_decay = args.weight_decay - if 'adamw' in opt_lower or 'radam' in opt_lower: - # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay - # I don't believe they follow the paper or original Torch7 impl which schedules weight - # decay based on the ratio of current_lr/initial_lr - weight_decay /= args.lr if weight_decay and filter_bias_and_bn: parameters = add_weight_decay(model, weight_decay) weight_decay = 0. @@ -43,66 +49,59 @@ def create_optimizer(args, model, filter_bias_and_bn=True): if 'fused' in opt_lower: assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' + opt_args = dict(lr=args.lr, weight_decay=weight_decay) + if args.opt_eps is not None: + opt_args['eps'] = args.opt_eps + if args.opt_betas is not None: + opt_args['betas'] = args.opt_betas + opt_split = opt_lower.split('_') opt_lower = opt_split[-1] if opt_lower == 'sgd' or opt_lower == 'nesterov': - optimizer = optim.SGD( - parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) + optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'momentum': - optimizer = optim.SGD( - parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False) + optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) elif opt_lower == 'adam': - optimizer = optim.Adam( - parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = optim.Adam(parameters, **opt_args) elif opt_lower == 'adamw': - optimizer = AdamW( - parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = optim.AdamW(parameters, **opt_args) elif opt_lower == 'nadam': - optimizer = Nadam( - parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = Nadam(parameters, **opt_args) elif opt_lower == 'radam': - optimizer = RAdam( - parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = RAdam(parameters, **opt_args) elif opt_lower == 'adamp': - optimizer = AdamP( - parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps, - delta=0.1, wd_ratio=0.01, nesterov=True) + optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) elif opt_lower == 'sgdp': - optimizer = SGDP( - parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, - eps=args.opt_eps, nesterov=True) + optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'adadelta': - optimizer = optim.Adadelta( - parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == 'adafactor': + if not args.lr: + opt_args['lr'] = None + optimizer = Adafactor(parameters, **opt_args) + elif opt_lower == 'adahessian': + optimizer = Adahessian(parameters, **opt_args) elif opt_lower == 'rmsprop': - optimizer = optim.RMSprop( - parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, - momentum=args.momentum, weight_decay=weight_decay) + optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) elif opt_lower == 'rmsproptf': - optimizer = RMSpropTF( - parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, - momentum=args.momentum, weight_decay=weight_decay) + optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) elif opt_lower == 'novograd': - optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = NovoGrad(parameters, **opt_args) elif opt_lower == 'nvnovograd': - optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = NvNovoGrad(parameters, **opt_args) elif opt_lower == 'fusedsgd': - optimizer = FusedSGD( - parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) + optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'fusedmomentum': - optimizer = FusedSGD( - parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False) + optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) elif opt_lower == 'fusedadam': - optimizer = FusedAdam( - parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) elif opt_lower == 'fusedadamw': - optimizer = FusedAdam( - parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) elif opt_lower == 'fusedlamb': - optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + optimizer = FusedLAMB(parameters, **opt_args) elif opt_lower == 'fusednovograd': - optimizer = FusedNovoGrad( - parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps) + opt_args.setdefault('betas', (0.95, 0.98)) + optimizer = FusedNovoGrad(parameters, **opt_args) else: assert False and "Invalid optimizer" raise ValueError diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index d972002c..bcd29f58 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -15,10 +15,10 @@ except ImportError: class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, parameters=None): + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - if clip_grad: + scaled_loss.backward(create_graph=create_graph) + if clip_grad is not None: torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) optimizer.step() @@ -37,9 +37,9 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, parameters=None): - self._scaler.scale(loss).backward() - if clip_grad: + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + self._scaler.scale(loss).backward(create_graph=create_graph) + if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place torch.nn.utils.clip_grad_norm_(parameters, clip_grad) diff --git a/train.py b/train.py index 3a235faf..ef3adf85 100755 --- a/train.py +++ b/train.py @@ -98,12 +98,18 @@ parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, defau # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') -parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', - help='Optimizer Epsilon (default: 1e-8)') +parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: None, use opt default)') +parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', - help='SGD momentum (default: 0.9)') + help='Optimizer momentum (default: 0.9)') parser.add_argument('--weight-decay', type=float, default=0.0001, help='weight decay (default: 0.0001)') +parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + + # Learning rate schedule parameters parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', @@ -595,6 +601,7 @@ def train_epoch( elif mixup_fn is not None: mixup_fn.mixup_enabled = False + second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order batch_time_m = AverageMeter() data_time_m = AverageMeter() losses_m = AverageMeter() @@ -623,9 +630,12 @@ def train_epoch( optimizer.zero_grad() if loss_scaler is not None: - loss_scaler(loss, optimizer) + loss_scaler( + loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) else: - loss.backward() + loss.backward(create_graph=second_order) + if args.clip_grad is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) optimizer.step() torch.cuda.synchronize()