From c468c47a9cecf9a3cb872dc7e24d203086b0f3b2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 16:41:04 -0700 Subject: [PATCH] Add regnety_160 weights from DeiT teacher model, update that and my regnety_032 weights to use higher test size. --- timm/models/regnet.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 26d8650b..3b7dba52 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -57,12 +57,13 @@ model_cfgs = dict( ) -def _cfg(url=''): +def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs } @@ -84,12 +85,16 @@ default_cfgs = dict( regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'), regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'), - regnety_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'), + regnety_032=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth', + crop_pct=1.0, test_input_size=(3, 288, 288)), regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'), regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'), regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'), regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'), - regnety_160=_cfg(url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'), + regnety_160=_cfg( + url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository + crop_pct=1.0, test_input_size=(3, 288, 288)), regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), ) @@ -328,11 +333,20 @@ class RegNet(nn.Module): return x +def _filter_fn(state_dict): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + if 'model' in state_dict: + # For DeiT trained regnety_160 pretraiend model + state_dict = state_dict['model'] + return state_dict + + def _create_regnet(variant, pretrained, **kwargs): return build_model_with_cfg( RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], + pretrained_filter_fn=_filter_fn, **kwargs)