commit
4d284017b8
@ -0,0 +1,704 @@
|
||||
""" Optimzier Tests
|
||||
|
||||
These tests were adapted from PyTorch' optimizer tests.
|
||||
|
||||
"""
|
||||
import math
|
||||
import pytest
|
||||
import functools
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.autograd import Variable
|
||||
from timm.scheduler import PlateauLRScheduler
|
||||
|
||||
from timm.optim import create_optimizer_v2
|
||||
|
||||
|
||||
# HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
|
||||
torch_tc = TestCase()
|
||||
|
||||
|
||||
def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
input = Variable(input)
|
||||
optimizer = constructor(weight, bias)
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
# to check if the optimizer can be printed as a string
|
||||
optimizer.__repr__()
|
||||
|
||||
def fn():
|
||||
optimizer.zero_grad()
|
||||
y = weight.mv(input)
|
||||
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
|
||||
y = y.cuda(bias.get_device())
|
||||
loss = (y + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
initial_value = fn().item()
|
||||
for _i in range(200):
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, PlateauLRScheduler):
|
||||
val_loss = fn()
|
||||
scheduler.step(val_loss)
|
||||
else:
|
||||
scheduler.step()
|
||||
optimizer.step(fn)
|
||||
|
||||
assert fn().item() < initial_value
|
||||
|
||||
|
||||
def _test_state_dict(weight, bias, input, constructor):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
input = Variable(input)
|
||||
|
||||
def fn_base(optimizer, weight, bias):
|
||||
optimizer.zero_grad()
|
||||
i = input_cuda if weight.is_cuda else input
|
||||
loss = (weight.mv(i) + bias).pow(2).sum()
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
optimizer = constructor(weight, bias)
|
||||
fn = functools.partial(fn_base, optimizer, weight, bias)
|
||||
|
||||
# Prime the optimizer
|
||||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
# Clone the weights and construct new optimizer for them
|
||||
weight_c = Variable(weight.data.clone(), requires_grad=True)
|
||||
bias_c = Variable(bias.data.clone(), requires_grad=True)
|
||||
optimizer_c = constructor(weight_c, bias_c)
|
||||
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
|
||||
# Load state dict
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_c.load_state_dict(state_dict_c)
|
||||
|
||||
# Run both optimizations in parallel
|
||||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_c.step(fn_c)
|
||||
#assert torch.equal(weight, weight_c)
|
||||
#assert torch.equal(bias, bias_c)
|
||||
torch_tc.assertEqual(weight, weight_c)
|
||||
torch_tc.assertEqual(bias, bias_c)
|
||||
# Make sure state dict wasn't modified
|
||||
torch_tc.assertEqual(state_dict, state_dict_c)
|
||||
# Make sure state dict is deterministic with equal but not identical parameters
|
||||
torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
|
||||
# Make sure repeated parameters have identical representation in state dict
|
||||
optimizer_c.param_groups.extend(optimizer_c.param_groups)
|
||||
torch_tc.assertEqual(optimizer.state_dict()['param_groups'][-1], optimizer_c.state_dict()['param_groups'][-1])
|
||||
|
||||
# Check that state dict can be loaded even when we cast parameters
|
||||
# to a different type and move to a different device.
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
input_cuda = Variable(input.data.float().cuda())
|
||||
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
|
||||
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
|
||||
optimizer_cuda = constructor(weight_cuda, bias_cuda)
|
||||
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
|
||||
|
||||
state_dict = deepcopy(optimizer.state_dict())
|
||||
state_dict_c = deepcopy(optimizer.state_dict())
|
||||
optimizer_cuda.load_state_dict(state_dict_c)
|
||||
|
||||
# Make sure state dict wasn't modified
|
||||
torch_tc.assertEqual(state_dict, state_dict_c)
|
||||
|
||||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_cuda.step(fn_cuda)
|
||||
torch_tc.assertEqual(weight, weight_cuda)
|
||||
torch_tc.assertEqual(bias, bias_cuda)
|
||||
|
||||
# validate deepcopy() copies all public attributes
|
||||
def getPublicAttr(obj):
|
||||
return set(k for k in obj.__dict__ if not k.startswith('_'))
|
||||
|
||||
assert getPublicAttr(optimizer) == getPublicAttr(deepcopy(optimizer))
|
||||
|
||||
|
||||
def _test_basic_cases(constructor, scheduler_constructors=None):
|
||||
if scheduler_constructors is None:
|
||||
scheduler_constructors = []
|
||||
_test_state_dict(
|
||||
torch.randn(10, 5),
|
||||
torch.randn(10),
|
||||
torch.randn(5),
|
||||
constructor
|
||||
)
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5),
|
||||
torch.randn(10),
|
||||
torch.randn(5),
|
||||
constructor,
|
||||
scheduler_constructors
|
||||
)
|
||||
# non-contiguous parameters
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5, 2)[..., 0],
|
||||
torch.randn(10, 2)[..., 0],
|
||||
torch.randn(5),
|
||||
constructor,
|
||||
scheduler_constructors
|
||||
)
|
||||
# CUDA
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
_test_basic_cases_template(
|
||||
torch.randn(10, 5).cuda(),
|
||||
torch.randn(10).cuda(),
|
||||
torch.randn(5).cuda(),
|
||||
constructor,
|
||||
scheduler_constructors
|
||||
)
|
||||
|
||||
|
||||
def _test_model(optimizer, params, device=torch.device('cpu')):
|
||||
weight = torch.tensor(
|
||||
[[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]],
|
||||
device=device, requires_grad=True)
|
||||
bias = torch.tensor([-0.1085, -0.2979, 0.6892], device=device, requires_grad=True)
|
||||
weight2 = torch.tensor([[-0.0508, -0.3941, -0.2843]], device=device, requires_grad=True)
|
||||
bias2 = torch.tensor([-0.0711], device=device, requires_grad=True)
|
||||
input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device=device).reshape(3, 2)
|
||||
|
||||
model = torch.nn.Sequential(torch.nn.Linear(2, 3),
|
||||
torch.nn.Sigmoid(),
|
||||
torch.nn.Linear(3, 1),
|
||||
torch.nn.Sigmoid())
|
||||
model.to(device)
|
||||
|
||||
pretrained_dict = model.state_dict()
|
||||
pretrained_dict['0.weight'] = weight
|
||||
pretrained_dict['0.bias'] = bias
|
||||
pretrained_dict['2.weight'] = weight2
|
||||
pretrained_dict['2.bias'] = bias2
|
||||
model.load_state_dict(pretrained_dict)
|
||||
|
||||
optimizer = create_optimizer_v2(model, opt=optimizer, **params)
|
||||
|
||||
prev_loss = float('inf')
|
||||
for i in range(20):
|
||||
optimizer.zero_grad()
|
||||
output = model(input)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
loss = loss.item()
|
||||
assert loss < prev_loss
|
||||
prev_loss = loss
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def rosenbrock(tensor):
|
||||
x, y = tensor
|
||||
return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2
|
||||
|
||||
|
||||
def drosenbrock(tensor):
|
||||
x, y = tensor
|
||||
return torch.tensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))
|
||||
|
||||
|
||||
def _test_rosenbrock(constructor, scheduler_constructors=None):
|
||||
if scheduler_constructors is None:
|
||||
scheduler_constructors = []
|
||||
params_t = torch.tensor([1.5, 1.5])
|
||||
|
||||
params = Variable(params_t, requires_grad=True)
|
||||
optimizer = constructor([params])
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
solution = torch.tensor([1, 1])
|
||||
initial_dist = params.data.dist(solution)
|
||||
|
||||
def eval(params, w):
|
||||
# Depending on w, provide only the x or y gradient
|
||||
optimizer.zero_grad()
|
||||
loss = rosenbrock(params)
|
||||
loss.backward()
|
||||
grad = drosenbrock(params.data)
|
||||
# NB: We torture test the optimizer by returning an
|
||||
# uncoalesced sparse tensor
|
||||
if w:
|
||||
i = torch.LongTensor([[0, 0]])
|
||||
x = grad[0]
|
||||
v = torch.tensor([x / 4., x - x / 4.])
|
||||
else:
|
||||
i = torch.LongTensor([[1, 1]])
|
||||
y = grad[1]
|
||||
v = torch.tensor([y - y / 4., y / 4.])
|
||||
x = torch.sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype)
|
||||
with torch.no_grad():
|
||||
params.grad = x.to_dense()
|
||||
return loss
|
||||
|
||||
for i in range(2000):
|
||||
# Do cyclic coordinate descent
|
||||
w = i % 2
|
||||
optimizer.step(functools.partial(eval, params, w))
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, PlateauLRScheduler):
|
||||
scheduler.step(rosenbrock(params))
|
||||
else:
|
||||
scheduler.step()
|
||||
|
||||
torch_tc.assertLessEqual(params.data.dist(solution), initial_dist)
|
||||
|
||||
|
||||
def _build_params_dict(weight, bias, **kwargs):
|
||||
return [{'params': [weight]}, dict(params=[bias], **kwargs)]
|
||||
|
||||
|
||||
def _build_params_dict_single(weight, bias, **kwargs):
|
||||
return [dict(params=bias, **kwargs)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
|
||||
def test_sgd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=1e-2),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-2),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
|
||||
)
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
|
||||
# [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
||||
# lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
||||
# lambda opt: ReduceLROnPlateau(opt)]
|
||||
# )
|
||||
# _test_basic_cases(
|
||||
# lambda weight, bias: optimizer([weight, bias], lr=1e-3),
|
||||
# [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
|
||||
# lambda opt: ExponentialLR(opt, gamma=0.99),
|
||||
# lambda opt: ReduceLROnPlateau(opt)]
|
||||
# )
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax'])
|
||||
def test_adam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adabelief'])
|
||||
def test_adabelief(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['radam', 'radabelief'])
|
||||
def test_rectified(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad'])
|
||||
def test_adaother(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-1)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adafactor'])
|
||||
def test_adafactor(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lamb'])
|
||||
def test_lamb(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
|
||||
def test_madgrad(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['novograd'])
|
||||
def test_novograd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf'])
|
||||
def test_rmsprop(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['adamp'])
|
||||
def test_adamp(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=5e-2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['sgdp'])
|
||||
def test_sgdp(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
_test_model(optimizer, dict(lr=1e-3))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum'])
|
||||
def test_lookahead_sgd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam'])
|
||||
def test_lookahead_adam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_radam'])
|
||||
def test_lookahead_radam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
|
||||
)
|
||||
|
@ -0,0 +1,190 @@
|
||||
""" PyTorch MADGRAD optimizer
|
||||
|
||||
MADGRAD: https://arxiv.org/abs/2101.11075
|
||||
|
||||
Code from: https://github.com/facebookresearch/madgrad
|
||||
"""
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.optim
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.optim.optimizer import _params_t
|
||||
else:
|
||||
_params_t = Any
|
||||
|
||||
|
||||
class MADGRAD(torch.optim.Optimizer):
|
||||
"""
|
||||
MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
|
||||
Optimization.
|
||||
|
||||
.. _MADGRAD: https://arxiv.org/abs/2101.11075
|
||||
|
||||
MADGRAD is a general purpose optimizer that can be used in place of SGD or
|
||||
Adam may converge faster and generalize better. Currently GPU-only.
|
||||
Typically, the same learning rate schedule that is used for SGD or Adam may
|
||||
be used. The overall learning rate is not comparable to either method and
|
||||
should be determined by a hyper-parameter sweep.
|
||||
|
||||
MADGRAD requires less weight decay than other methods, often as little as
|
||||
zero. Momentum values used for SGD or Adam's beta1 should work here also.
|
||||
|
||||
On sparse problems both weight_decay and momentum should be set to 0.
|
||||
|
||||
Arguments:
|
||||
params (iterable):
|
||||
Iterable of parameters to optimize or dicts defining parameter groups.
|
||||
lr (float):
|
||||
Learning rate (default: 1e-2).
|
||||
momentum (float):
|
||||
Momentum value in the range [0,1) (default: 0.9).
|
||||
weight_decay (float):
|
||||
Weight decay, i.e. a L2 penalty (default: 0).
|
||||
eps (float):
|
||||
Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: _params_t,
|
||||
lr: float = 1e-2,
|
||||
momentum: float = 0.9,
|
||||
weight_decay: float = 0,
|
||||
eps: float = 1e-6,
|
||||
decoupled_decay: bool = False,
|
||||
):
|
||||
if momentum < 0 or momentum >= 1:
|
||||
raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
|
||||
if lr <= 0:
|
||||
raise ValueError(f"Learning rate {lr} must be positive")
|
||||
if weight_decay < 0:
|
||||
raise ValueError(f"Weight decay {weight_decay} must be non-negative")
|
||||
if eps < 0:
|
||||
raise ValueError(f"Eps must be non-negative")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@property
|
||||
def supports_memory_efficient_fp16(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_flat_params(self) -> bool:
|
||||
return True
|
||||
|
||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
# step counter must be stored in state to ensure correct behavior under
|
||||
# optimizer sharding
|
||||
if 'k' not in self.state:
|
||||
self.state['k'] = torch.tensor([0], dtype=torch.long)
|
||||
k = self.state['k'].item()
|
||||
|
||||
for group in self.param_groups:
|
||||
eps = group["eps"]
|
||||
lr = group["lr"] + eps
|
||||
weight_decay = group["weight_decay"]
|
||||
momentum = group["momentum"]
|
||||
|
||||
ck = 1 - momentum
|
||||
lamb = lr * math.pow(k + 1, 0.5)
|
||||
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
state = self.state[p]
|
||||
|
||||
if "grad_sum_sq" not in state:
|
||||
state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
|
||||
state["s"] = torch.zeros_like(p.data).detach()
|
||||
if momentum != 0:
|
||||
state["x0"] = torch.clone(p.data).detach()
|
||||
|
||||
if momentum != 0.0 and grad.is_sparse:
|
||||
raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
|
||||
|
||||
grad_sum_sq = state["grad_sum_sq"]
|
||||
s = state["s"]
|
||||
|
||||
# Apply weight decay
|
||||
if weight_decay != 0:
|
||||
if group['decoupled_decay']:
|
||||
p.data.mul_(1.0 - group['lr'] * weight_decay)
|
||||
else:
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("weight_decay option is not compatible with sparse gradients")
|
||||
grad.add_(p.data, alpha=weight_decay)
|
||||
|
||||
if grad.is_sparse:
|
||||
grad = grad.coalesce()
|
||||
grad_val = grad._values()
|
||||
|
||||
p_masked = p.sparse_mask(grad)
|
||||
grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
|
||||
s_masked = s.sparse_mask(grad)
|
||||
|
||||
# Compute x_0 from other known quantities
|
||||
rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
|
||||
x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)
|
||||
|
||||
# Dense + sparse op
|
||||
grad_sq = grad * grad
|
||||
grad_sum_sq.add_(grad_sq, alpha=lamb)
|
||||
grad_sum_sq_masked.add_(grad_sq, alpha=lamb)
|
||||
|
||||
rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
|
||||
|
||||
s.add_(grad, alpha=lamb)
|
||||
s_masked._values().add_(grad_val, alpha=lamb)
|
||||
|
||||
# update masked copy of p
|
||||
p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
|
||||
# Copy updated masked p to dense p using an add operation
|
||||
p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
|
||||
p.data.add_(p_masked, alpha=-1)
|
||||
else:
|
||||
if momentum == 0:
|
||||
# Compute x_0 from other known quantities
|
||||
rms = grad_sum_sq.pow(1 / 3).add_(eps)
|
||||
x0 = p.data.addcdiv(s, rms, value=1)
|
||||
else:
|
||||
x0 = state["x0"]
|
||||
|
||||
# Accumulate second moments
|
||||
grad_sum_sq.addcmul_(grad, grad, value=lamb)
|
||||
rms = grad_sum_sq.pow(1 / 3).add_(eps)
|
||||
|
||||
# Update s
|
||||
s.data.add_(grad, alpha=lamb)
|
||||
|
||||
# Step
|
||||
if momentum == 0:
|
||||
p.data.copy_(x0.addcdiv(s, rms, value=-1))
|
||||
else:
|
||||
z = x0.addcdiv(s, rms, value=-1)
|
||||
|
||||
# p is a moving average of z
|
||||
p.data.mul_(1 - ck).add_(z, alpha=ck)
|
||||
|
||||
self.state['k'] += 1
|
||||
return loss
|
@ -1,77 +0,0 @@
|
||||
"""NovoGrad Optimizer.
|
||||
Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
|
||||
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
|
||||
- https://arxiv.org/abs/1905.11286
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
import math
|
||||
|
||||
|
||||
class NovoGrad(Optimizer):
|
||||
def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super(NovoGrad, self).__init__(params, defaults)
|
||||
self._lr = lr
|
||||
self._beta1 = betas[0]
|
||||
self._beta2 = betas[1]
|
||||
self._eps = eps
|
||||
self._wd = weight_decay
|
||||
self._grad_averaging = grad_averaging
|
||||
|
||||
self._momentum_initialized = False
|
||||
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
if not self._momentum_initialized:
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
state = self.state[p]
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('NovoGrad does not support sparse gradients')
|
||||
|
||||
v = torch.norm(grad)**2
|
||||
m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
|
||||
state['step'] = 0
|
||||
state['v'] = v
|
||||
state['m'] = m
|
||||
state['grad_ema'] = None
|
||||
self._momentum_initialized = True
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
state = self.state[p]
|
||||
state['step'] += 1
|
||||
|
||||
step, v, m = state['step'], state['v'], state['m']
|
||||
grad_ema = state['grad_ema']
|
||||
|
||||
grad = p.grad.data
|
||||
g2 = torch.norm(grad)**2
|
||||
grad_ema = g2 if grad_ema is None else grad_ema * \
|
||||
self._beta2 + g2 * (1. - self._beta2)
|
||||
grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)
|
||||
|
||||
if self._grad_averaging:
|
||||
grad *= (1. - self._beta1)
|
||||
|
||||
g2 = torch.norm(grad)**2
|
||||
v = self._beta2*v + (1. - self._beta2)*g2
|
||||
m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
|
||||
bias_correction1 = 1 - self._beta1 ** step
|
||||
bias_correction2 = 1 - self._beta2 ** step
|
||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
state['v'], state['m'] = v, m
|
||||
state['grad_ema'] = grad_ema
|
||||
p.data.add_(-step_size, m)
|
||||
return loss
|
Loading…
Reference in new issue