|
|
|
@ -57,12 +57,13 @@ model_cfgs = dict(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url=''):
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -84,12 +85,16 @@ default_cfgs = dict(
|
|
|
|
|
regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
|
|
|
|
|
regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'),
|
|
|
|
|
regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'),
|
|
|
|
|
regnety_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'),
|
|
|
|
|
regnety_032=_cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth',
|
|
|
|
|
crop_pct=1.0, test_input_size=(3, 288, 288)),
|
|
|
|
|
regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'),
|
|
|
|
|
regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
|
|
|
|
|
regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
|
|
|
|
|
regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
|
|
|
|
|
regnety_160=_cfg(url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'),
|
|
|
|
|
regnety_160=_cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
|
|
|
|
|
crop_pct=1.0, test_input_size=(3, 288, 288)),
|
|
|
|
|
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -328,11 +333,20 @@ class RegNet(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _filter_fn(state_dict):
|
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
|
if 'model' in state_dict:
|
|
|
|
|
# For DeiT trained regnety_160 pretraiend model
|
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_regnet(variant, pretrained, **kwargs):
|
|
|
|
|
return build_model_with_cfg(
|
|
|
|
|
RegNet, variant, pretrained,
|
|
|
|
|
default_cfg=default_cfgs[variant],
|
|
|
|
|
model_cfg=model_cfgs[variant],
|
|
|
|
|
pretrained_filter_fn=_filter_fn,
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|