better densenet121 and densenetblur121d weights

pull/155/head
Ross Wightman 5 years ago
parent 7be299504f
commit e78daf586a

@ -30,13 +30,16 @@ def _cfg(url=''):
default_cfgs = { default_cfgs = {
'densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'), 'densenet121': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
'densenet121d': _cfg(url=''), 'densenet121d': _cfg(url=''),
'densenet121tn': _cfg(url=''), 'densenetblur121d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'), 'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'), 'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'), 'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
'densenet264': _cfg(url=''), 'densenet264': _cfg(url=''),
'tv_densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
} }
@ -160,7 +163,8 @@ class DenseNet(nn.Module):
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='', def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='',
num_classes=1000, in_chans=3, global_pool='avg', num_classes=1000, in_chans=3, global_pool='avg',
norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False): norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
aa_stem_only=True):
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
super(DenseNet, self).__init__() super(DenseNet, self).__init__()
@ -209,10 +213,11 @@ class DenseNet(nn.Module):
) )
self.features.add_module('denseblock%d' % (i + 1), block) self.features.add_module('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate num_features = num_features + num_layers * growth_rate
transition_aa_layer = None if aa_stem_only else aa_layer
if i != len(block_config) - 1: if i != len(block_config) - 1:
trans = DenseTransition( trans = DenseTransition(
num_input_features=num_features, num_output_features=num_features // 2, num_input_features=num_features, num_output_features=num_features // 2,
norm_layer=norm_layer) norm_layer=norm_layer, aa_layer=transition_aa_layer)
self.features.add_module('transition%d' % (i + 1), trans) self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2 num_features = num_features // 2
@ -310,7 +315,7 @@ def densenetblur121d(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>` `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
""" """
model = _densenet( model = _densenet(
'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep', 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
aa_layer=BlurPool2d, **kwargs) aa_layer=BlurPool2d, **kwargs)
return model return model
@ -326,17 +331,6 @@ def densenet121d(pretrained=False, **kwargs):
return model return model
@register_model
def densenet121tn(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep_tiered_narrow',
pretrained=pretrained, **kwargs)
return model
@register_model @register_model
def densenet121d_evob(pretrained=False, **kwargs): def densenet121d_evob(pretrained=False, **kwargs):
r"""Densenet-121 model from r"""Densenet-121 model from
@ -414,3 +408,13 @@ def densenet264(pretrained=False, **kwargs):
model = _densenet( model = _densenet(
'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs) 'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
return model return model
@register_model
def tv_densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model with original Torchvision weights, from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _densenet(
'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
return model

Loading…
Cancel
Save