diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py index 482c0c01..d2bb9f72 100644 --- a/timm/models/layers/adaptive_avgmax_pool.py +++ b/timm/models/layers/adaptive_avgmax_pool.py @@ -49,6 +49,15 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1): return x +class FastAdaptiveAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) + + class AdaptiveAvgMaxPool2d(nn.Module): def __init__(self, output_size=1): super(AdaptiveAvgMaxPool2d, self).__init__() @@ -70,12 +79,16 @@ class AdaptiveCatAvgMaxPool2d(nn.Module): class SelectAdaptivePool2d(nn.Module): """Selectable global pooling layer with dynamic input kernel size """ - def __init__(self, output_size=1, pool_type='avg', flatten=False): + def __init__(self, output_size=1, pool_type='fast', flatten=False): super(SelectAdaptivePool2d, self).__init__() self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing self.flatten = flatten if pool_type == '': self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool2d(self.flatten) + self.flatten = False elif pool_type == 'avg': self.pool = nn.AdaptiveAvgPool2d(output_size) elif pool_type == 'avgmax': diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 75b545e5..e371292f 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -14,7 +14,7 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import build_model_with_cfg -from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule from .registry import register_model __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] @@ -49,40 +49,6 @@ default_cfgs = { } -class FastGlobalAvgPool2d(nn.Module): - def __init__(self, flatten=False): - super(FastGlobalAvgPool2d, self).__init__() - self.flatten = flatten - - def forward(self, x): - if self.flatten: - return x.mean((2, 3)) - else: - return x.mean((2, 3), keepdim=True) - - def feat_mult(self): - return 1 - - -class FastSEModule(nn.Module): - - def __init__(self, channels, reduction_channels, inplace=True): - super(FastSEModule, self).__init__() - self.avg_pool = FastGlobalAvgPool2d() - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True) - self.relu = nn.ReLU(inplace=inplace) - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True) - self.activation = nn.Sigmoid() - - def forward(self, x): - x_se = self.avg_pool(x) - x_se2 = self.fc1(x_se) - x_se2 = self.relu(x_se2) - x_se = self.fc2(x_se2) - x_se = self.activation(x_se) - return x * x_se - - def IABN2Float(module: nn.Module) -> nn.Module: """If `module` is IABN don't use half precision.""" if isinstance(module, InplaceAbn): @@ -119,8 +85,8 @@ class BasicBlock(nn.Module): self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride - reduce_layer_planes = max(planes * self.expansion // 4, 64) - self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None + reduction_chs = max(planes * self.expansion // 4, 64) + self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None def forward(self, x): if self.downsample is not None: @@ -159,8 +125,8 @@ class Bottleneck(nn.Module): conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), aa_layer(channels=planes, filt_size=3, stride=2)) - reduce_layer_planes = max(planes * self.expansion // 8, 64) - self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None + reduction_chs = max(planes * self.expansion // 8, 64) + self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None self.conv3 = conv2d_iabn( planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") @@ -189,7 +155,7 @@ class Bottleneck(nn.Module): class TResNet(nn.Module): def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, - global_pool='avg', drop_rate=0.): + global_pool='fast', drop_rate=0.): self.num_classes = num_classes self.drop_rate = drop_rate super(TResNet, self).__init__() @@ -272,7 +238,7 @@ class TResNet(nn.Module): def get_classifier(self): return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes, global_pool='fast'): self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)