diff --git a/timm/bits/train_setup.py b/timm/bits/train_setup.py index 5aca908f..87c1b5c5 100644 --- a/timm/bits/train_setup.py +++ b/timm/bits/train_setup.py @@ -72,9 +72,10 @@ def setup_model_and_optimizer( 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') if isinstance(optimizer, Callable): - optimizer = optimizer(model=model, **optimizer_cfg) + # FIXME this interface needs to be figured out, model, model and/or parameters, or just parameters? + optimizer = optimizer(model, **optimizer_cfg) else: - optimizer = create_optimizer_v2(model=model, **optimizer_cfg) + optimizer = create_optimizer_v2(model, **optimizer_cfg) updater = create_updater( model=model,