Update efficientnet.py

add a method to get the path of local .pth file(weights)
pull/967/head
jim4399266 4 years ago committed by GitHub
parent 65419f60cc
commit a457b6d14d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -560,6 +560,16 @@ def _create_effnet(variant, pretrained=False, **kwargs):
features_only = True
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
model_cls = EfficientNetFeatures
# default_cfg = default_cfgs[variant]
if kwargs.get('pth_path', None):
# if use local weights and weights are existed
if os.path.exists(kwargs['pth_path']):
default_cfgs[variant].update({f'pth_path': kwargs.pop('pth_path')})
else:
print(f'Didn\'t find weights file:{kwargs["pth_path"]}')
model = build_model_with_cfg(
model_cls, variant, pretrained,
default_cfg=default_cfgs[variant],

Loading…
Cancel
Save