Add Facebook Research Semi-Supervised and Semi-Weakly Supervised ResNet model weights.

pull/52/head
Ross Wightman 5 years ago
parent a9eb484835
commit b93fcf0708

@ -167,6 +167,65 @@ model_list = [
_entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d (288x288 Mean-Max Pooling)', '1805.00932', _entry('ig_resnext101_32x48d', 'ResNeXt-101 32x48d (288x288 Mean-Max Pooling)', '1805.00932',
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 8), ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 8),
## Facebook SSL weights
_entry('ssl_resnet18', 'ResNet-18', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnet50', 'ResNet-50', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('ssl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
## Facebook SWSL weights
_entry('swsl_resnet18', 'ResNet-18', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnet50', 'ResNet-50', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d', '1905.00546',
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d', '1905.00546'),
_entry('swsl_resnet50', 'ResNet-50 (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext50_32x4d', 'ResNeXt-50 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext101_32x4d', 'ResNeXt-101 32x4d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext101_32x8d', 'ResNeXt-101 32x8d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288),
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
_entry('swsl_resnext101_32x16d', 'ResNeXt-101 32x16d (288x288 Mean-Max Pooling)', '1905.00546',
ttp=True, args=dict(img_size=288), batch_size=BATCH_SIZE // 2,
model_desc='Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/'),
## DLA official impl weights (to remove if sotabench added to source) ## DLA official impl weights (to remove if sotabench added to source)
_entry('dla34', 'DLA-34', '1707.06484'), _entry('dla34', 'DLA-34', '1707.06484'),
_entry('dla46_c', 'DLA-46-C', '1707.06484'), _entry('dla46_c', 'DLA-46-C', '1707.06484'),

@ -57,15 +57,17 @@ def resume_checkpoint(model, checkpoint_path):
raise FileNotFoundError() raise FileNotFoundError()
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None):
if 'url' not in default_cfg or not default_cfg['url']: if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
logging.warning("Pretrained model URL is invalid, using random initialization.") logging.warning("Pretrained model URL is invalid, using random initialization.")
return return
state_dict = model_zoo.load_url(default_cfg['url'], progress=False) state_dict = model_zoo.load_url(cfg['url'], progress=False)
if in_chans == 1: if in_chans == 1:
conv1_name = default_cfg['first_conv'] conv1_name = cfg['first_conv']
logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight'] conv1_weight = state_dict[conv1_name + '.weight']
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
@ -73,14 +75,14 @@ def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=
assert False, "Invalid in_chans for pretrained weights" assert False, "Invalid in_chans for pretrained weights"
strict = True strict = True
classifier_name = default_cfg['classifier'] classifier_name = cfg['classifier']
if num_classes == 1000 and default_cfg['num_classes'] == 1001: if num_classes == 1000 and cfg['num_classes'] == 1001:
# special case for imagenet trained models with extra background class in pretrained weights # special case for imagenet trained models with extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight'] classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[1:] state_dict[classifier_name + '.weight'] = classifier_weight[1:]
classifier_bias = state_dict[classifier_name + '.bias'] classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[1:] state_dict[classifier_name + '.bias'] = classifier_bias[1:]
elif num_classes != default_cfg['num_classes']: elif num_classes != cfg['num_classes']:
# completely discard fully connected for all other differences between pretrained and created model # completely discard fully connected for all other differences between pretrained and created model
del state_dict[classifier_name + '.weight'] del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias'] del state_dict[classifier_name + '.bias']

@ -67,6 +67,30 @@ default_cfgs = {
'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'), 'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'),
'ssl_resnet18': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'),
'ssl_resnet50': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth'),
'ssl_resnext50_32x4d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth'),
'ssl_resnext101_32x4d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth'),
'ssl_resnext101_32x8d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'),
'ssl_resnext101_32x16d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'),
'swsl_resnet18': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'),
'swsl_resnet50': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth'),
'swsl_resnext50_32x4d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth'),
'swsl_resnext101_32x4d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth'),
'swsl_resnext101_32x8d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'),
'swsl_resnext101_32x16d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
} }
@ -621,80 +645,218 @@ def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
@register_model @register_model
def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): def ig_resnext101_32x8d(pretrained=True, **kwargs):
"""Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args: """
pretrained (bool): load pretrained weights model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
num_classes (int): number of classes for classifier (default: 1000 for pretrained) model.default_cfg = default_cfgs['ig_resnext101_32x8d']
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x8d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model
@register_model @register_model
def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): def ig_resnext101_32x16d(pretrained=True, **kwargs):
"""Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args: """
pretrained (bool): load pretrained weights model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
num_classes (int): number of classes for classifier (default: 1000 for pretrained) model.default_cfg = default_cfgs['ig_resnext101_32x16d']
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x16d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model
@register_model @register_model
def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): def ig_resnext101_32x32d(pretrained=True, **kwargs):
"""Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args: """
pretrained (bool): load pretrained weights model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
num_classes (int): number of classes for classifier (default: 1000 for pretrained) model.default_cfg = default_cfgs['ig_resnext101_32x32d']
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x32d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model
@register_model @register_model
def ig_resnext101_32x48d(pretrained=True, num_classes=1000, in_chans=3, **kwargs): def ig_resnext101_32x48d(pretrained=True, **kwargs):
"""Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args: """
pretrained (bool): load pretrained weights model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
num_classes (int): number of classes for classifier (default: 1000 for pretrained) model.default_cfg = default_cfgs['ig_resnext101_32x48d']
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x48d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48,
num_classes=1000, in_chans=3, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans) load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def ssl_resnet18(pretrained=True, **kwargs):
"""Constructs a semi-supervised ResNet-18 model pre-trained on YFCC100M dataset and finetuned on ImageNet
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
model.default_cfg = default_cfgs['ssl_resnet18']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def ssl_resnet50(pretrained=True, **kwargs):
"""Constructs a semi-supervised ResNet-50 model pre-trained on YFCC100M dataset and finetuned on ImageNet
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
model.default_cfg = default_cfgs['ssl_resnet50']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def ssl_resnext50_32x4d(pretrained=True, **kwargs):
"""Constructs a semi-supervised ResNeXt-50 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext50_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def ssl_resnext101_32x4d(pretrained=True, **kwargs):
"""Constructs a semi-supervised ResNeXt-101 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext101_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def ssl_resnext101_32x8d(pretrained=True, **kwargs):
"""Constructs a semi-supervised ResNeXt-101 32x8 model pre-trained on YFCC100M dataset and finetuned on ImageNet
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext101_32x8d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def ssl_resnext101_32x16d(pretrained=True, **kwargs):
"""Constructs a semi-supervised ResNeXt-101 32x16 model pre-trained on YFCC100M dataset and finetuned on ImageNet
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
model.default_cfg = default_cfgs['ssl_resnext101_32x16d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def swsl_resnet18(pretrained=True, **kwargs):
"""Constructs a semi-weakly supervised Resnet-18 model pre-trained on 1B weakly supervised
image dataset and finetuned on ImageNet.
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
model.default_cfg = default_cfgs['swsl_resnet18']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def swsl_resnet50(pretrained=True, **kwargs):
"""Constructs a semi-weakly supervised ResNet-50 model pre-trained on 1B weakly supervised
image dataset and finetuned on ImageNet.
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
model.default_cfg = default_cfgs['swsl_resnet50']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def swsl_resnext50_32x4d(pretrained=True, **kwargs):
"""Constructs a semi-weakly supervised ResNeXt-50 32x4 model pre-trained on 1B weakly supervised
image dataset and finetuned on ImageNet.
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext50_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def swsl_resnext101_32x4d(pretrained=True, **kwargs):
"""Constructs a semi-weakly supervised ResNeXt-101 32x4 model pre-trained on 1B weakly supervised
image dataset and finetuned on ImageNet.
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext101_32x4d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def swsl_resnext101_32x8d(pretrained=True, **kwargs):
"""Constructs a semi-weakly supervised ResNeXt-101 32x8 model pre-trained on 1B weakly supervised
image dataset and finetuned on ImageNet.
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext101_32x8d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def swsl_resnext101_32x16d(pretrained=True, **kwargs):
"""Constructs a semi-weakly supervised ResNeXt-101 32x16 model pre-trained on 1B weakly supervised
image dataset and finetuned on ImageNet.
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
model.default_cfg = default_cfgs['swsl_resnext101_32x16d']
if pretrained:
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model return model

Loading…
Cancel
Save