From a457b6d14d3341ecd5323dd98a005552cbc6bc7f Mon Sep 17 00:00:00 2001 From: jim4399266 <32267011+jim4399266@users.noreply.github.com> Date: Fri, 12 Nov 2021 10:47:55 +0800 Subject: [PATCH] Update efficientnet.py add a method to get the path of local .pth file(weights) --- timm/models/efficientnet.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 3d50b704..0c6dddf1 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -560,6 +560,16 @@ def _create_effnet(variant, pretrained=False, **kwargs): features_only = True kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') model_cls = EfficientNetFeatures + + +# default_cfg = default_cfgs[variant] + if kwargs.get('pth_path', None): + # if use local weights and weights are existed + if os.path.exists(kwargs['pth_path']): + default_cfgs[variant].update({f'pth_path': kwargs.pop('pth_path')}) + else: + print(f'Didn\'t find weights file:{kwargs["pth_path"]}') + model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant],