diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 842d18f9..06788d2e 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -44,7 +44,7 @@ def param_groups_weight_decay( if not param.requires_grad: continue - if param.ndim or name.endswith(".bias") or name in no_weight_decay_list: + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: no_decay.append(param) else: decay.append(param)