diff --git a/timm/models/densenet.py b/timm/models/densenet.py index c8be8683..4b774c4a 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -8,6 +8,8 @@ from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained @@ -28,53 +30,121 @@ def _cfg(url=''): default_cfgs = { 'densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'), + 'densenet121d': _cfg(url=''), + 'densenet121tn': _cfg(url=''), 'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'), 'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'), 'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'), } -class _DenseLayer(nn.Sequential): - def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): +class _DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_rate=0., memory_efficient=False): super(_DenseLayer, self).__init__() - self.add_module('norm1', nn.BatchNorm2d(num_input_features)), - self.add_module('relu1', nn.ReLU(inplace=True)), - self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * - growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), - self.add_module('relu2', nn.ReLU(inplace=True)), - self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, - kernel_size=3, stride=1, padding=1, bias=False)), - self.drop_rate = drop_rate + self.add_module('norm1', norm_layer(num_input_features)), + self.add_module('relu1', act_layer(inplace=True)), + self.add_module('conv1', nn.Conv2d( + num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', norm_layer(bn_size * growth_rate)), + self.add_module('relu2', act_layer(inplace=True)), + self.add_module('conv2', nn.Conv2d( + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bn_function(self, inputs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(inputs, 1) + bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, input): + # type: (List[torch.Tensor]) -> bool + for tensor in input: + if tensor.requires_grad: + return True + return False + + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, input): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*inputs): + return self.bn_function(*inputs) + + return cp.checkpoint(closure, input) + + @torch.jit._overload_method # noqa: F811 + def forward(self, input): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, input): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + # torchscript does not yet support *args, so we overload method + # allowing it to take either a List[Tensor] or single Tensor + def forward(self, input): # noqa: F811 + if isinstance(input, torch.Tensor): + prev_features = [input] + else: + prev_features = input - def forward(self, x): - new_features = super(_DenseLayer, self).forward(x) + if self.memory_efficient and self.any_requires_grad(prev_features): + if torch.jit.is_scripting(): + raise Exception("Memory Efficient not supported in JIT") + bottleneck_output = self.call_checkpoint_bottleneck(prev_features) + else: + bottleneck_output = self.bn_function(prev_features) + + new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) if self.drop_rate > 0: new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) - return torch.cat([x, new_features], 1) + return new_features -class _DenseBlock(nn.Sequential): - def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): +class _DenseBlock(nn.ModuleDict): + _version = 2 + + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, drop_rate=0., memory_efficient=False): super(_DenseBlock, self).__init__() for i in range(num_layers): - layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) + layer = _DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + act_layer=act_layer, + norm_layer=norm_layer, + drop_rate=drop_rate, + memory_efficient=memory_efficient, + ) self.add_module('denselayer%d' % (i + 1), layer) + def forward(self, init_features): + features = [init_features] + for name, layer in self.items(): + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + class _Transition(nn.Sequential): - def __init__(self, num_input_features, num_output_features): + def __init__(self, num_input_features, num_output_features, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(_Transition, self).__init__() - self.add_module('norm', nn.BatchNorm2d(num_input_features)) - self.add_module('relu', nn.ReLU(inplace=True)) - self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, - kernel_size=1, stride=1, bias=False)) + self.add_module('norm', norm_layer(num_input_features)) + self.add_module('relu', act_layer(inplace=True)) + self.add_module('conv', nn.Conv2d( + num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) class DenseNet(nn.Module): r"""Densenet-BC model class, based on - `"Densely Connected Convolutional Networks" ` + `"Densely Connected Convolutional Networks" `_ Args: growth_rate (int) - how many filters to add each layer (`k` in paper) @@ -84,44 +154,87 @@ class DenseNet(nn.Module): (i.e. bn_size * k features in the bottleneck layer) drop_rate (float) - dropout rate after each dense layer num_classes (int) - number of classification classes + memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, + but slower. Default: *False*. See `"paper" `_ """ - def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), - num_init_features=64, bn_size=4, drop_rate=0, - num_classes=1000, in_chans=3, global_pool='avg'): + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, + bn_size=4, stem_type='', num_classes=1000, in_chans=3, global_pool='avg', + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0, memory_efficient=False): self.num_classes = num_classes self.drop_rate = drop_rate + deep_stem = 'deep' in stem_type super(DenseNet, self).__init__() # First convolution - self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), - ('norm0', nn.BatchNorm2d(num_init_features)), - ('relu0', nn.ReLU(inplace=True)), - ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), - ])) + if aa_layer is None: + max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + max_pool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=self.inplanes, stride=2)]) + if deep_stem: + stem_chs_1 = stem_chs_2 = num_init_features // 2 + if 'tiered' in stem_type: + stem_chs_1 = 3 * (num_init_features // 8) + stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (num_init_features // 8) + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), + ('norm0', norm_layer(stem_chs_1)), + ('relu0', act_layer(inplace=True)), + ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), + ('norm1', norm_layer(stem_chs_2)), + ('relu1', act_layer(inplace=True)), + ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), + ('norm2', norm_layer(num_init_features)), + ('relu2', act_layer(inplace=True)), + ('pool0', max_pool), + ])) + else: + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ('norm0', norm_layer(num_init_features)), + ('relu0', act_layer(inplace=True)), + ('pool0', max_pool), + ])) # Each denseblock num_features = num_init_features for i, num_layers in enumerate(block_config): - block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, - bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate, + memory_efficient=memory_efficient + ) self.features.add_module('denseblock%d' % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: - trans = _Transition( - num_input_features=num_features, num_output_features=num_features // 2) + trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) self.features.add_module('transition%d' % (i + 1), trans) num_features = num_features // 2 # Final batch norm - self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + self.features.add_module('norm5', norm_layer(num_features)) + self.act = act_layer(inplace=True) # Linear layer self.num_features = num_features self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + # Official init from torch repo. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + def get_classifier(self): return self.classifier @@ -136,19 +249,20 @@ class DenseNet(nn.Module): def forward_features(self, x): x = self.features(x) - x = F.relu(x, inplace=True) + x = self.act(x) return x def forward(self, x): x = self.forward_features(x) x = self.global_pool(x).flatten(1) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + # both classifier and block drop? + # if self.drop_rate > 0.: + # x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.classifier(x) return x -def _filter_pretrained(state_dict): +def _filter_torchvision_pretrained(state_dict): pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') @@ -161,57 +275,90 @@ def _filter_pretrained(state_dict): return state_dict +def _densenet(variant, growth_rate, block_config, num_init_features, pretrained, **kwargs): + if kwargs.pop('features_only', False): + assert False, 'Not Implemented' # TODO + load_strict = False + kwargs.pop('num_classes', 0) + model_class = DenseNet + else: + load_strict = True + model_class = DenseNet + default_cfg = default_cfgs[variant] + model = model_class( + growth_rate=growth_rate, block_config=block_config, num_init_features=num_init_features, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, default_cfg, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3), + filter_fn=_filter_torchvision_pretrained, + strict=load_strict) + return model + + @register_model -def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def densenet121(pretrained=False, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" ` """ - default_cfg = default_cfgs['densenet121'] - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + model = _densenet( + 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet121d(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _densenet( + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, + stem_type='deep', pretrained=pretrained, **kwargs) return model @register_model -def densenet169(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def densenet121tn(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _densenet( + 'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, + stem_type='deep_tiered_narrow', pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet169(pretrained=False, **kwargs): r"""Densenet-169 model from `"Densely Connected Convolutional Networks" ` """ - default_cfg = default_cfgs['densenet169'] - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + model = _densenet( + 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64, + pretrained=pretrained, **kwargs) return model @register_model -def densenet201(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def densenet201(pretrained=False, **kwargs): r"""Densenet-201 model from `"Densely Connected Convolutional Networks" ` """ - default_cfg = default_cfgs['densenet201'] - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + model = _densenet( + 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64, + pretrained=pretrained, **kwargs) return model @register_model -def densenet161(pretrained=False, num_classes=1000, in_chans=3, **kwargs): +def densenet161(pretrained=False, **kwargs): r"""Densenet-201 model from `"Densely Connected Convolutional Networks" ` """ - default_cfg = default_cfgs['densenet161'] - model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) + model = _densenet( + 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96, + pretrained=pretrained, **kwargs) return model