diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 901d7d44..32a35304 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -179,11 +179,11 @@ def load_pretrained( return if filter_fn is not None: - # for backwards compat with filter fn that take one arg, try one first, the two try: - state_dict = filter_fn(state_dict) - except TypeError: state_dict = filter_fn(state_dict, model) + except TypeError as e: + # for backwards compat with filter fn that take one arg + state_dict = filter_fn(state_dict) input_convs = pretrained_cfg.get('first_conv', None) if input_convs is not None and in_chans != 3: