Another attempt at sgd momentum test passing...

pull/853/head
Ross Wightman 3 years ago
parent 78933122c9
commit 54e90e82a5

@ -317,10 +317,10 @@ def test_sgd(optimizer):
# lambda opt: ReduceLROnPlateau(opt)] # lambda opt: ReduceLROnPlateau(opt)]
# ) # )
_test_basic_cases( _test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1) lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1)
) )
_test_basic_cases( _test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=.1) lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1)
) )
_test_rosenbrock( _test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3) lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)

Loading…
Cancel
Save