diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py
index c1db890b..5838981c 100644
--- a/timm/models/layers/adaptive_avgmax_pool.py
+++ b/timm/models/layers/adaptive_avgmax_pool.py
@@ -70,10 +70,11 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
 class SelectAdaptivePool2d(nn.Module):
     """Selectable global pooling layer with dynamic input kernel size
     """
-    def __init__(self, output_size=1, pool_type='avg'):
+    def __init__(self, output_size=1, pool_type='avg', flatten=False):
         super(SelectAdaptivePool2d, self).__init__()
         self.output_size = output_size
         self.pool_type = pool_type
+        self.flatten = flatten
         if pool_type == 'avgmax':
             self.pool = AdaptiveAvgMaxPool2d(output_size)
         elif pool_type == 'catavgmax':
@@ -86,7 +87,10 @@ class SelectAdaptivePool2d(nn.Module):
             self.pool = nn.AdaptiveAvgPool2d(output_size)
 
     def forward(self, x):
-        return self.pool(x)
+        x = self.pool(x)
+        if self.flatten:
+            x = x.flatten(1)
+        return x
 
     def feat_mult(self):
         return adaptive_pool_feat_mult(self.pool_type)
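Reviewer note: the new `flatten` flag lets the pooling layer hand classifier heads an (N, C) tensor directly instead of (N, C, 1, 1). A quick sanity sketch of the flag together with `feat_mult()` (not part of the patch; `SelectAdaptivePool2d` is re-exported from `timm.models.layers`):

    import torch
    from timm.models.layers import SelectAdaptivePool2d

    pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
    out = pool(torch.randn(2, 64, 7, 7))
    # 'catavgmax' concatenates avg- and max-pooled features, so feat_mult() == 2
    assert out.shape == (2, 64 * pool.feat_mult())  # (2, 128)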
diff --git a/timm/models/layers/anti_aliasing.py b/timm/models/layers/anti_aliasing.py
index a1f7535a..fd6457bf 100644
--- a/timm/models/layers/anti_aliasing.py
+++ b/timm/models/layers/anti_aliasing.py
@@ -5,13 +5,14 @@ import torch.nn.functional as F
 
 
 class AntiAliasDownsampleLayer(nn.Module):
-    def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2,
-                 channels: int = 0):
+    def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0):
         super(AntiAliasDownsampleLayer, self).__init__()
-        if not remove_aa_jit:
-            self.op = DownsampleJIT(filt_size, stride, channels)
-        else:
+        if no_jit:
             self.op = Downsample(filt_size, stride, channels)
+        else:
+            self.op = DownsampleJIT(filt_size, stride, channels)
+
+        # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
 
     def forward(self, x):
         return self.op(x)
@@ -23,20 +24,24 @@ class DownsampleJIT(object):
         self.stride = stride
         self.filt_size = filt_size
         self.channels = channels
-        assert self.filt_size == 3
         assert stride == 2
-        a = torch.tensor([1., 2., 1.])
+        self.filt = {}  # lazy init by device for DataParallel compat
 
-        filt = (a[:, None] * a[None, :]).clone().detach()
+    def _create_filter(self, like: torch.Tensor):
+        filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
+        filt = filt[:, None] * filt[None, :]
         filt = filt / torch.sum(filt)
-        self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()
+        filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
+        return filt
 
     def __call__(self, input: torch.Tensor):
-        if input.dtype != self.filt.dtype:
-            self.filt = self.filt.float()
         input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])
+        filt = self.filt.get(str(input.device))
+        if filt is None:
+            # create the filter on first use and cache it, so later calls on this device skip the rebuild
+            filt = self.filt[str(input.device)] = self._create_filter(input)
+        return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
 
 
 class Downsample(nn.Module):
@@ -46,11 +51,9 @@ class Downsample(nn.Module):
         self.stride = stride
         self.channels = channels
 
-        assert self.filt_size == 3
-        a = torch.tensor([1., 2., 1.])
-
-        filt = (a[:, None] * a[None, :])
+        filt = torch.tensor([1., 2., 1.])
+        filt = filt[:, None] * filt[None, :]
         filt = filt / torch.sum(filt)
         # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
 
@@ -58,4 +61,4 @@ class Downsample(nn.Module):
     def forward(self, input):
         input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
\ No newline at end of file
+        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
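Reviewer note: both downsample paths blur with the same kernel — the outer product of [1, 2, 1] with itself, normalized to sum to 1 — applied depthwise (`groups=channels`) before the stride-2 subsample. The dict keyed by `str(input.device)` replaces the old eager `.cuda().half()` buffer, so CPU runs and `nn.DataParallel` replicas each build the filter on the device the input lives on. A minimal shape check on the plain-module path (not part of the patch; `no_jit=True` selects the ordinary `nn.Module` implementation):

    import torch
    from timm.models.layers import AntiAliasDownsampleLayer

    aa = AntiAliasDownsampleLayer(no_jit=True, channels=32)
    y = aa(torch.randn(1, 32, 56, 56))
    assert y.shape == (1, 32, 28, 28)  # reflect-pad, 3x3 blur, stride-2 subsample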
diff --git a/timm/models/layers/space_to_depth.py b/timm/models/layers/space_to_depth.py
index 70bf7db9..2c378fe1 100644
--- a/timm/models/layers/space_to_depth.py
+++ b/timm/models/layers/space_to_depth.py
@@ -28,9 +28,9 @@ class SpaceToDepthJit(object):
 
 
 class SpaceToDepthModule(nn.Module):
-    def __init__(self, remove_model_jit=False):
+    def __init__(self, no_jit=False):
         super().__init__()
-        if not remove_model_jit:
+        if not no_jit:
             self.op = SpaceToDepthJit()
         else:
             self.op = SpaceToDepth()
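Reviewer note: the rename aside, the op is unchanged — it folds every 4x4 patch into the channel dim, turning (N, C, H, W) into (N, 16C, H/4, W/4), which is why the TResNet stem conv below takes `in_chans * 16`. A quick sketch on the non-JIT path (not part of the patch):

    import torch
    from timm.models.layers import SpaceToDepthModule

    s2d = SpaceToDepthModule(no_jit=True)
    y = s2d(torch.randn(1, 3, 224, 224))
    assert y.shape == (1, 3 * 16, 56, 56)  # 4x4 blocks -> 16x channels, /4 spatial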
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index dc01e0fb..84b5cb31 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -8,8 +8,9 @@ Original model: https://github.com/mrT23/TResNet
 from functools import partial
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from collections import OrderedDict
-from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer
+from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d
 from .registry import register_model
 from .helpers import load_pretrained
 
@@ -27,18 +28,27 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': (0, 0, 0), 'std': (1, 1, 1),
-        'first_conv': 'layer0.conv1', 'classifier': 'head',
+        'first_conv': 'layer0.conv1', 'classifier': 'head.fc',
         **kwargs
     }
 
 
 default_cfgs = {
-    'tresnet_m':
-        _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_m_80_8.pth'),
-    'tresnet_l':
-        _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_l_81_5.pth'),
-    'tresnet_xl':
-        _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_xl_82_0.pth')
+    'tresnet_m': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_80_8-dbc13962.pth'),
+    'tresnet_l': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
+    'tresnet_xl': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
+    'tresnet_m_448': _cfg(
+        input_size=(3, 448, 448),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
+    'tresnet_l_448': _cfg(
+        input_size=(3, 448, 448),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
+    'tresnet_xl_448': _cfg(
+        input_size=(3, 448, 448),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth')
 }
 
@@ -54,6 +64,9 @@ class FastGlobalAvgPool2d(nn.Module):
         else:
             return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
 
+    def feat_mult(self):
+        return 1
+
 
 class FastSEModule(nn.Module):
 
@@ -78,14 +91,15 @@ def IABN2Float(module: nn.Module) -> nn.Module:
     "If `module` is IABN don't use half precision."
     if isinstance(module, InPlaceABN):
         module.float()
-    for child in module.children(): IABN2Float(child)
+    for child in module.children():
+        IABN2Float(child)
     return module
 
 
 def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1):
     return nn.Sequential(
-        nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups,
-                  bias=False),
+        nn.Conv2d(
+            ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False),
         InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param)
     )
 
@@ -101,8 +115,9 @@ class BasicBlock(nn.Module):
         if anti_alias_layer is None:
             self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3)
         else:
-            self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
-                                       anti_alias_layer(channels=planes, filt_size=3, stride=2))
+            self.conv1 = nn.Sequential(
+                conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
+                anti_alias_layer(channels=planes, filt_size=3, stride=2))
 
         self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity")
         self.relu = nn.ReLU(inplace=True)
@@ -120,12 +135,11 @@ class BasicBlock(nn.Module):
 
         out = self.conv1(x)
         out = self.conv2(out)
-        if self.se is not None: out = self.se(out)
+        if self.se is not None:
+            out = self.se(out)
 
         out += residual
-
         out = self.relu(out)
-
         return out
 
@@ -134,22 +148,22 @@ class Bottleneck(nn.Module):
     def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
         super(Bottleneck, self).__init__()
-        self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu",
-                                activation_param=1e-3)
+        self.conv1 = conv2d_ABN(
+            inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu", activation_param=1e-3)
         if stride == 1:
-            self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu",
-                                    activation_param=1e-3)
+            self.conv2 = conv2d_ABN(
+                planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3)
         else:
             if anti_alias_layer is None:
-                self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu",
-                                        activation_param=1e-3)
+                self.conv2 = conv2d_ABN(
+                    planes, planes, kernel_size=3, stride=2, activation="leaky_relu", activation_param=1e-3)
             else:
-                self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1,
-                                                      activation="leaky_relu", activation_param=1e-3),
-                                           anti_alias_layer(channels=planes, filt_size=3, stride=2))
+                self.conv2 = nn.Sequential(
+                    conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3),
+                    anti_alias_layer(channels=planes, filt_size=3, stride=2))
 
-        self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1,
-                                activation="identity")
+        self.conv3 = conv2d_ABN(
+            planes, planes * self.expansion, kernel_size=1, stride=1, activation="identity")
 
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
@@ -166,7 +180,8 @@ class Bottleneck(nn.Module):
 
         out = self.conv1(x)
         out = self.conv2(out)
-        if self.se is not None: out = self.se(out)
+        if self.se is not None:
+            out = self.se(out)
 
         out = self.conv3(out)
         out = out + residual  # no inplace
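Reviewer note: the `conv2d_ABN` call-site reformatting above is behavior-neutral — the helper still returns a Conv2d followed by InPlaceABN (fused BatchNorm + activation). A hedged sketch of its contract; it assumes inplace_abn is installed (importing `conv2d_ABN` fails otherwise) and runs on CUDA because some inplace_abn builds ship GPU kernels only:

    import torch
    from timm.models.tresnet import conv2d_ABN  # needs inplace_abn installed

    block = conv2d_ABN(32, 64, stride=2, activation_param=1e-3).cuda()
    y = block(torch.randn(2, 32, 56, 56, device='cuda'))
    assert y.shape == (2, 64, 28, 28)  # 3x3 conv, padding=1, stride=2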
@@ -176,29 +191,32 @@
 
 class TResNet(nn.Module):
-    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, remove_aa_jit=False):
+    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
+                 global_pool='avg', drop_rate=0.):
         if not has_iabn:
-            raise " For TResNet models, please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11' "
-
+            raise ImportError(
+                "For TResNet models, please install InplaceABN: "
+                "'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
         super(TResNet, self).__init__()
 
         # JIT layers
         space_to_depth = SpaceToDepthModule()
-        anti_alias_layer = partial(AntiAliasDownsampleLayer, remove_aa_jit=remove_aa_jit)
-        global_pool_layer = FastGlobalAvgPool2d(flatten=True)
+        anti_alias_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
 
         # TResnet stages
         self.inplanes = int(64 * width_factor)
         self.planes = int(64 * width_factor)
         conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3)
-        layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True,
-                                  anti_alias_layer=anti_alias_layer)  # 56x56
-        layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True,
-                                  anti_alias_layer=anti_alias_layer)  # 28x28
-        layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True,
-                                  anti_alias_layer=anti_alias_layer)  # 14x14
-        layer4 = self._make_layer(Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False,
-                                  anti_alias_layer=anti_alias_layer)  # 7x7
+        layer1 = self._make_layer(
+            BasicBlock, self.planes, layers[0], stride=1, use_se=True, anti_alias_layer=anti_alias_layer)  # 56x56
+        layer2 = self._make_layer(
+            BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, anti_alias_layer=anti_alias_layer)  # 28x28
+        layer3 = self._make_layer(
+            Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, anti_alias_layer=anti_alias_layer)  # 14x14
+        layer4 = self._make_layer(
+            Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, anti_alias_layer=anti_alias_layer)  # 7x7
 
         # body
         self.body = nn.Sequential(OrderedDict([
@@ -210,11 +228,10 @@ class TResNet(nn.Module):
             ('layer4', layer4)]))
 
         # head
-        self.embeddings = []
-        self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)]))
         self.num_features = (self.planes * 8) * Bottleneck.expansion
-        fc = nn.Linear(self.num_features, num_classes)
-        self.head = nn.Sequential(OrderedDict([('fc', fc)]))
+        self.global_pool = None
+        self.head = None
+        self.reset_classifier(num_classes, global_pool)
 
         # model initialization
         for m in self.modules():
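Reviewer note: the head rework drops the stored `self.embeddings` side effect in favor of timm's standard `forward_features()` / `get_classifier()` / `reset_classifier()` interface (the methods appear in the next hunk). Assuming inplace_abn is available, intended usage looks like:

    import torch
    from timm.models.tresnet import tresnet_m

    model = tresnet_m()                      # head.fc: Linear(2048, 1000)
    feats = model.forward_features(torch.randn(1, 3, 224, 224))
    assert feats.shape == (1, 2048, 7, 7)    # body output, before pooling

    model.reset_classifier(10)               # swap in a fresh 10-class head
    assert model.get_classifier().out_features == 10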
@@ -239,54 +256,106 @@ class TResNet(nn.Module):
             if stride == 2:
                 # avg pooling before 1x1 conv
                 layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
-            layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1,
-                                  activation="identity")]
+            layers += [conv2d_ABN(
+                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, activation="identity")]
             downsample = nn.Sequential(*layers)
 
         layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se,
-                            anti_alias_layer=anti_alias_layer))
+        layers.append(block(
+            self.inplanes, planes, stride, downsample, use_se=use_se, anti_alias_layer=anti_alias_layer))
         self.inplanes = planes * block.expansion
-        for i in range(1, blocks): layers.append(
-            block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
+        for i in range(1, blocks):
+            layers.append(
+                block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
         return nn.Sequential(*layers)
 
-    def forward(self, x):
-        x = self.body(x)
-        self.embeddings = self.global_pool(x)
-        logits = self.head(self.embeddings)
-        return logits
+    def get_classifier(self):
+        return self.head.fc
 
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        if global_pool == 'avg':
+            self.global_pool = FastGlobalAvgPool2d(flatten=True)
+        else:
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+        if num_classes:
+            self.head = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))]))
+        else:
+            # identity 'fc' keeps forward() usable when the classifier is removed (num_classes=0)
+            self.head = nn.Sequential(OrderedDict([('fc', nn.Identity())]))
 
-def filter_fn(input):
-    return input['model']
+    def forward_features(self, x):
+        return self.body(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.global_pool(x)
+        if self.drop_rate:
+            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+        x = self.head(x)
+        return x
 
 
 @register_model
 def tresnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['tresnet_m']
-    model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans)
+    model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
+        load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
 
 
 @register_model
 def tresnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['tresnet_l']
-    model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2)
+    model = TResNet(
+        layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
+        load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
 
 
 @register_model
 def tresnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['tresnet_xl']
-    model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3)
+    model = TResNet(
+        layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tresnet_m_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    default_cfg = default_cfgs['tresnet_m_448']
+    model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tresnet_l_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    default_cfg = default_cfgs['tresnet_l_448']
+    model = TResNet(
+        layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tresnet_xl_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    default_cfg = default_cfgs['tresnet_xl_448']
+    model = TResNet(
+        layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
+        load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
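Reviewer note: since the entrypoints now forward `**kwargs` into the constructor, `drop_rate` and `global_pool` can be passed through `timm.create_model`, and the 448 variants are plain registry entries that share their architectures with the base models but load 448x448 weights. A sketch (downloads weights; requires inplace_abn):

    import timm
    import torch

    model = timm.create_model('tresnet_m_448', pretrained=True, drop_rate=0.2)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 448, 448))
    assert logits.shape == (1, 1000)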