Add Adafactor and Adahessian optimizers, cleanup optimizer arg passing, add gradient clipping support.

pull/256/head
Ross Wightman 4 years ago
parent fcb6258877
commit 80078c47bb

@ -7,7 +7,8 @@ pip install -r requirements-sotabench.txt
apt-get update apt-get update
apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev
pip uninstall -y pillow pip uninstall -y pillow
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd CFLAGS="${CFLAGS} -mavx2" pip install -U --no-cache-dir --force-reinstall --no-binary :all:--compile https://github.com/mrT23/pillow-simd/zipball/simd/7.0.x
#CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
# FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work # FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work
apt-get install wget apt-get install wget

@ -1,10 +1,13 @@
from .nadam import Nadam from .adamp import AdamP
from .rmsprop_tf import RMSpropTF
from .adamw import AdamW from .adamw import AdamW
from .radam import RAdam from .adafactor import Adafactor
from .adahessian import Adahessian
from .lookahead import Lookahead
from .nadam import Nadam
from .novograd import NovoGrad from .novograd import NovoGrad
from .nvnovograd import NvNovoGrad from .nvnovograd import NvNovoGrad
from .lookahead import Lookahead from .radam import RAdam
from .adamp import AdamP from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP from .sgdp import SGDP
from .optim_factory import create_optimizer from .optim_factory import create_optimizer

@ -0,0 +1,174 @@
""" Adafactor Optimizer
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
Original header/copyright below.
"""
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import math
class Adafactor(torch.optim.Optimizer):
"""Implements Adafactor algorithm.
This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
(see https://arxiv.org/abs/1804.04235)
Note that this optimizer internally adjusts the learning rate depending on the
*scale_parameter*, *relative_step* and *warmup_init* options.
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
`relative_step=False`.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): external learning rate (default: None)
eps (tuple[float, float]): regularization constants for square gradient
and parameter scale respectively (default: (1e-30, 1e-3))
clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
beta1 (float): coefficient used for computing running averages of gradient (default: None)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
relative_step (bool): if True, time-dependent learning rate is computed
instead of external learning rate (default: True)
warmup_init (bool): time-dependent learning rate computation depends on
whether warm-up initialization is being used (default: False)
"""
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
relative_step = lr is None
if warmup_init and not relative_step:
raise ValueError('warmup_init requires relative_step=True')
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
relative_step=relative_step, warmup_init=warmup_init)
super(Adafactor, self).__init__(params, defaults)
@staticmethod
def _get_lr(param_group, param_state):
if param_group['relative_step']:
min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
param_scale = 1.0
if param_group['scale_parameter']:
param_scale = max(param_group['eps_scale'], param_state['RMS'])
param_group['lr'] = lr_t * param_scale
return param_group['lr']
@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
use_first_moment = param_group['beta1'] is not None
return factored, use_first_moment
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError('Adafactor does not support sparse gradients.')
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state['step'] = 0
if use_first_moment:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(grad)
if factored:
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state['exp_avg_sq'] = torch.zeros_like(grad)
state['RMS'] = 0
else:
if use_first_moment:
state['exp_avg'] = state['exp_avg'].to(grad)
if factored:
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
else:
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state['step'] += 1
state['RMS'] = self._rms(p_data_fp32)
lr_t = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
update = grad ** 2 + group['eps']
if factored:
exp_avg_sq_row = state['exp_avg_sq_row']
exp_avg_sq_col = state['exp_avg_sq_col']
exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
#exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
#exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
#exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
update.mul_(lr_t)
if use_first_moment:
exp_avg = state['exp_avg']
exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
#exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
update = exp_avg
if group['weight_decay'] != 0:
p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
#p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
p_data_fp32.add_(-update)
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
return loss

@ -0,0 +1,156 @@
""" AdaHessian Optimizer
Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
Originally licensed MIT, Copyright 2020, David Samuel
"""
import torch
class Adahessian(torch.optim.Optimizer):
"""
Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning"
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate (default: 0.1)
betas ((float, float), optional): coefficients used for computing running averages of gradient and the
squared hessian trace (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
(to save time) (default: 1)
n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
"""
def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= hessian_power <= 1.0:
raise ValueError(f"Invalid Hessian power value: {hessian_power}")
self.n_samples = n_samples
self.update_each = update_each
self.avg_conv_kernel = avg_conv_kernel
# use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
self.seed = 2147483647
self.generator = torch.Generator().manual_seed(self.seed)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
super(Adahessian, self).__init__(params, defaults)
for p in self.get_params():
p.hess = 0.0
self.state[p]["hessian step"] = 0
@property
def is_second_order(self):
return True
def get_params(self):
"""
Gets all parameters in all param_groups with gradients
"""
return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
def zero_hessian(self):
"""
Zeros out the accumalated hessian traces.
"""
for p in self.get_params():
if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
p.hess.zero_()
@torch.no_grad()
def set_hessian(self):
"""
Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
"""
params = []
for p in filter(lambda p: p.grad is not None, self.get_params()):
if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
params.append(p)
self.state[p]["hessian step"] += 1
if len(params) == 0:
return
if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
grads = [p.grad for p in params]
for i in range(self.n_samples):
# Rademacher distribution {-1.0, 1.0}
zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
h_zs = torch.autograd.grad(
grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
for h_z, z, p in zip(h_zs, zs, params):
p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step.
Arguments:
closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
"""
loss = None
if closure is not None:
loss = closure()
self.zero_hessian()
self.set_hessian()
for group in self.param_groups:
for p in group['params']:
if p.grad is None or p.hess is None:
continue
if self.avg_conv_kernel and p.dim() == 4:
p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
# Perform correct stepweight decay as in AdamW
p.mul_(1 - group['lr'] * group['weight_decay'])
state = self.state[p]
# State initialization
if len(state) == 1:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
# Exponential moving average of Hessian diagonal square values
state['exp_hessian_diag_sq'] = torch.zeros_like(p)
exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
k = group['hessian_power']
denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
# make update
step_size = group['lr'] / bias_correction1
p.addcdiv_(exp_avg, denom, value=-step_size)
return loss

@ -3,7 +3,18 @@ Hacked together by / Copyright 2020 Ross Wightman
""" """
import torch import torch
from torch import optim as optim from torch import optim as optim
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP, SGDP
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .lookahead import Lookahead
from .nadam import Nadam
from .novograd import NovoGrad
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
try: try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True has_apex = True
@ -29,11 +40,6 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
def create_optimizer(args, model, filter_bias_and_bn=True): def create_optimizer(args, model, filter_bias_and_bn=True):
opt_lower = args.opt.lower() opt_lower = args.opt.lower()
weight_decay = args.weight_decay weight_decay = args.weight_decay
if 'adamw' in opt_lower or 'radam' in opt_lower:
# Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
# I don't believe they follow the paper or original Torch7 impl which schedules weight
# decay based on the ratio of current_lr/initial_lr
weight_decay /= args.lr
if weight_decay and filter_bias_and_bn: if weight_decay and filter_bias_and_bn:
parameters = add_weight_decay(model, weight_decay) parameters = add_weight_decay(model, weight_decay)
weight_decay = 0. weight_decay = 0.
@ -43,66 +49,59 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
if 'fused' in opt_lower: if 'fused' in opt_lower:
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
if args.opt_eps is not None:
opt_args['eps'] = args.opt_eps
if args.opt_betas is not None:
opt_args['betas'] = args.opt_betas
opt_split = opt_lower.split('_') opt_split = opt_lower.split('_')
opt_lower = opt_split[-1] opt_lower = opt_split[-1]
if opt_lower == 'sgd' or opt_lower == 'nesterov': if opt_lower == 'sgd' or opt_lower == 'nesterov':
optimizer = optim.SGD( optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
elif opt_lower == 'momentum': elif opt_lower == 'momentum':
optimizer = optim.SGD( optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
elif opt_lower == 'adam': elif opt_lower == 'adam':
optimizer = optim.Adam( optimizer = optim.Adam(parameters, **opt_args)
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
elif opt_lower == 'adamw': elif opt_lower == 'adamw':
optimizer = AdamW( optimizer = optim.AdamW(parameters, **opt_args)
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
elif opt_lower == 'nadam': elif opt_lower == 'nadam':
optimizer = Nadam( optimizer = Nadam(parameters, **opt_args)
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
elif opt_lower == 'radam': elif opt_lower == 'radam':
optimizer = RAdam( optimizer = RAdam(parameters, **opt_args)
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
elif opt_lower == 'adamp': elif opt_lower == 'adamp':
optimizer = AdamP( optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
delta=0.1, wd_ratio=0.01, nesterov=True)
elif opt_lower == 'sgdp': elif opt_lower == 'sgdp':
optimizer = SGDP( optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay,
eps=args.opt_eps, nesterov=True)
elif opt_lower == 'adadelta': elif opt_lower == 'adadelta':
optimizer = optim.Adadelta( optimizer = optim.Adadelta(parameters, **opt_args)
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) elif opt_lower == 'adafactor':
if not args.lr:
opt_args['lr'] = None
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'adahessian':
optimizer = Adahessian(parameters, **opt_args)
elif opt_lower == 'rmsprop': elif opt_lower == 'rmsprop':
optimizer = optim.RMSprop( optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=weight_decay)
elif opt_lower == 'rmsproptf': elif opt_lower == 'rmsproptf':
optimizer = RMSpropTF( optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=weight_decay)
elif opt_lower == 'novograd': elif opt_lower == 'novograd':
optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) optimizer = NovoGrad(parameters, **opt_args)
elif opt_lower == 'nvnovograd': elif opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'fusedsgd': elif opt_lower == 'fusedsgd':
optimizer = FusedSGD( optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
elif opt_lower == 'fusedmomentum': elif opt_lower == 'fusedmomentum':
optimizer = FusedSGD( optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
elif opt_lower == 'fusedadam': elif opt_lower == 'fusedadam':
optimizer = FusedAdam( optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
elif opt_lower == 'fusedadamw': elif opt_lower == 'fusedadamw':
optimizer = FusedAdam( optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps)
elif opt_lower == 'fusedlamb': elif opt_lower == 'fusedlamb':
optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) optimizer = FusedLAMB(parameters, **opt_args)
elif opt_lower == 'fusednovograd': elif opt_lower == 'fusednovograd':
optimizer = FusedNovoGrad( opt_args.setdefault('betas', (0.95, 0.98))
parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps) optimizer = FusedNovoGrad(parameters, **opt_args)
else: else:
assert False and "Invalid optimizer" assert False and "Invalid optimizer"
raise ValueError raise ValueError

@ -15,10 +15,10 @@ except ImportError:
class ApexScaler: class ApexScaler:
state_dict_key = "amp" state_dict_key = "amp"
def __call__(self, loss, optimizer, clip_grad=None, parameters=None): def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
with amp.scale_loss(loss, optimizer) as scaled_loss: with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward(create_graph=create_graph)
if clip_grad: if clip_grad is not None:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
optimizer.step() optimizer.step()
@ -37,9 +37,9 @@ class NativeScaler:
def __init__(self): def __init__(self):
self._scaler = torch.cuda.amp.GradScaler() self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, parameters=None): def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
self._scaler.scale(loss).backward() self._scaler.scale(loss).backward(create_graph=create_graph)
if clip_grad: if clip_grad is not None:
assert parameters is not None assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
torch.nn.utils.clip_grad_norm_(parameters, clip_grad) torch.nn.utils.clip_grad_norm_(parameters, clip_grad)

@ -98,12 +98,18 @@ parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, defau
# Optimizer parameters # Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"') help='Optimizer (default: "sgd"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)') help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)') help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001, parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)') help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
# Learning rate schedule parameters # Learning rate schedule parameters
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
@ -595,6 +601,7 @@ def train_epoch(
elif mixup_fn is not None: elif mixup_fn is not None:
mixup_fn.mixup_enabled = False mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
batch_time_m = AverageMeter() batch_time_m = AverageMeter()
data_time_m = AverageMeter() data_time_m = AverageMeter()
losses_m = AverageMeter() losses_m = AverageMeter()
@ -623,9 +630,12 @@ def train_epoch(
optimizer.zero_grad() optimizer.zero_grad()
if loss_scaler is not None: if loss_scaler is not None:
loss_scaler(loss, optimizer) loss_scaler(
loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
else: else:
loss.backward() loss.backward(create_graph=second_order)
if args.clip_grad is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
optimizer.step() optimizer.step()
torch.cuda.synchronize() torch.cuda.synchronize()

Loading…
Cancel
Save