Add 'fast' global pool option; remove the redundant FastSEModule from tresnet, since the normal SEModule is now 'fast'

pull/233/head
Ross Wightman 4 years ago
parent 90a01f47d1
commit 80c9d9cc72

@ -49,6 +49,15 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
return x return x
class FastAdaptiveAvgPool2d(nn.Module):
    """Global average pooling implemented with ``Tensor.mean`` over dims 2 and 3.

    Args:
        flatten: if True, the two reduced dims are dropped from the output
            (e.g. ``(N, C, H, W)`` -> ``(N, C)``); if False they are kept with
            size 1 (e.g. ``(N, C, H, W)`` -> ``(N, C, 1, 1)``).
    """

    def __init__(self, flatten=False):
        super(FastAdaptiveAvgPool2d, self).__init__()
        self.flatten = flatten

    def forward(self, x):
        """Return the mean of ``x`` over dims 2 and 3."""
        # keepdim is the inverse of flatten: keeping the reduced dims preserves
        # the input rank, dropping them yields the flattened form.
        return x.mean((2, 3), keepdim=not self.flatten)
class AdaptiveAvgMaxPool2d(nn.Module): class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1): def __init__(self, output_size=1):
super(AdaptiveAvgMaxPool2d, self).__init__() super(AdaptiveAvgMaxPool2d, self).__init__()
@ -70,12 +79,16 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
class SelectAdaptivePool2d(nn.Module): class SelectAdaptivePool2d(nn.Module):
"""Selectable global pooling layer with dynamic input kernel size """Selectable global pooling layer with dynamic input kernel size
""" """
def __init__(self, output_size=1, pool_type='avg', flatten=False): def __init__(self, output_size=1, pool_type='fast', flatten=False):
super(SelectAdaptivePool2d, self).__init__() super(SelectAdaptivePool2d, self).__init__()
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = flatten self.flatten = flatten
if pool_type == '': if pool_type == '':
self.pool = nn.Identity() # pass through self.pool = nn.Identity() # pass through
elif pool_type == 'fast':
assert output_size == 1
self.pool = FastAdaptiveAvgPool2d(self.flatten)
self.flatten = False
elif pool_type == 'avg': elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size) self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax': elif pool_type == 'avgmax':

@ -14,7 +14,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
from .registry import register_model from .registry import register_model
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@ -49,40 +49,6 @@ default_cfgs = {
} }
class FastGlobalAvgPool2d(nn.Module):
    """Global average pooling implemented with ``Tensor.mean`` over dims 2 and 3.

    Args:
        flatten: if True, the reduced dims are dropped from the output;
            otherwise they are kept with size 1.
    """

    def __init__(self, flatten=False):
        super(FastGlobalAvgPool2d, self).__init__()
        self.flatten = flatten

    def forward(self, x):
        """Return the mean of ``x`` over dims 2 and 3."""
        if self.flatten:
            return x.mean((2, 3))
        return x.mean((2, 3), keepdim=True)

    def feat_mult(self):
        """Return the constant 1.

        NOTE(review): presumably the factor by which this pooling multiplies
        the feature/channel count -- confirm against callers.
        """
        return 1
class FastSEModule(nn.Module):
    """Squeeze-and-excitation style channel gate.

    The input is pooled with ``FastGlobalAvgPool2d``, passed through a
    1x1 conv -> ReLU -> 1x1 conv -> sigmoid stack, and the result is
    multiplied back onto the input.

    Args:
        channels: number of input (and output) channels of the gate.
        reduction_channels: number of channels between the two 1x1 convs.
        inplace: forwarded to ``nn.ReLU(inplace=...)``.
    """

    def __init__(self, channels, reduction_channels, inplace=True):
        super(FastSEModule, self).__init__()
        self.avg_pool = FastGlobalAvgPool2d()
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
        self.relu = nn.ReLU(inplace=inplace)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the gate computed from its pooled features."""
        gate = self.avg_pool(x)  # pooled with keepdim=True, so it broadcasts against x
        gate = self.relu(self.fc1(gate))
        gate = self.activation(self.fc2(gate))
        return x * gate
def IABN2Float(module: nn.Module) -> nn.Module: def IABN2Float(module: nn.Module) -> nn.Module:
"""If `module` is IABN don't use half precision.""" """If `module` is IABN don't use half precision."""
if isinstance(module, InplaceAbn): if isinstance(module, InplaceAbn):
@ -119,8 +85,8 @@ class BasicBlock(nn.Module):
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
reduce_layer_planes = max(planes * self.expansion // 4, 64) reduction_chs = max(planes * self.expansion // 4, 64)
self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
def forward(self, x): def forward(self, x):
if self.downsample is not None: if self.downsample is not None:
@ -159,8 +125,8 @@ class Bottleneck(nn.Module):
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
aa_layer(channels=planes, filt_size=3, stride=2)) aa_layer(channels=planes, filt_size=3, stride=2))
reduce_layer_planes = max(planes * self.expansion // 8, 64) reduction_chs = max(planes * self.expansion // 8, 64)
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
self.conv3 = conv2d_iabn( self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
@ -189,7 +155,7 @@ class Bottleneck(nn.Module):
class TResNet(nn.Module): class TResNet(nn.Module):
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
global_pool='avg', drop_rate=0.): global_pool='fast', drop_rate=0.):
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
super(TResNet, self).__init__() super(TResNet, self).__init__()
@ -272,7 +238,7 @@ class TResNet(nn.Module):
def get_classifier(self): def get_classifier(self):
return self.head.fc return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='fast'):
self.head = ClassifierHead( self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

Loading…
Cancel
Save