import torch
import torch.nn as nn
from torch.nn import functional as F

from .layers import create_conv2d, create_attn, drop_path

# Defaults used for Google/Tensorflow training of mobile networks w/ RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (w/ .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)


def get_bn_args_tf():
    """Return a fresh copy of the TF-compatible BatchNorm kwargs (``momentum`` and ``eps``)."""
    return _BN_ARGS_TF.copy()


def resolve_bn_args(kwargs):
    """Pop BatchNorm options out of ``kwargs`` and build norm-layer kwargs.

    Removes the ``bn_tf``, ``bn_momentum`` and ``bn_eps`` keys from ``kwargs``. When
    ``bn_tf`` is truthy the TF defaults (see ``get_bn_args_tf``) are the starting point;
    explicit non-None ``bn_momentum`` / ``bn_eps`` values override them.

    Args:
        kwargs: dict of model kwargs; mutated in place (the keys above are popped).

    Returns:
        dict holding whichever of ``momentum`` / ``eps`` were resolved (possibly empty).
    """
    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
    bn_momentum = kwargs.pop('bn_momentum', None)
    if bn_momentum is not None:
        bn_args['momentum'] = bn_momentum
    bn_eps = kwargs.pop('bn_eps', None)
    if bn_eps is not None:
        bn_args['eps'] = bn_eps
    return bn_args


def resolve_attn_args(layer, kwargs, in_chs, act_layer=None):
    """Build the kwargs handed to ``create_attn`` for a block's attention layer.

    Only SqueezeExciteV2-style layers receive extra arguments. Such a layer is either the
    string ``'sev2'``, or a class / instance whose class name contains ``'SqueezeExciteV2'``.
    The extra arguments are:
      * ``reduced_base_chs`` = ``in_chs``, unless the user kwargs contain a truthy
        ``reduce_mid`` (that key is always popped for SE layers);
      * ``act_layer`` = the containing block's ``act_layer`` when the user kwargs do not
        provide a non-None one.

    Args:
        layer: attention layer spec: a string name, or a module class / instance.
        kwargs: user attention kwargs, or None. Not mutated (a copy is made).
        in_chs: input channels of the containing block.
        act_layer: the containing block's activation layer, used as the SE fallback.

    Returns:
        A new dict of attention kwargs.

    Raises:
        AssertionError: for SE layers when no activation layer can be resolved.
    """
    attn_kwargs = kwargs.copy() if kwargs is not None else {}
    if isinstance(layer, str):
        is_se = layer == 'sev2'
    else:
        # Layers are usually passed as a module *class*. `isinstance(cls, nn.Module)` is False
        # for a class, and nn.Module instances have no `__name__`, so match on the class name.
        layer_cls = layer if isinstance(layer, type) else type(layer)
        is_se = 'SqueezeExciteV2' in layer_cls.__name__
    if is_se:
        # some models, like MobileNetV3, calculate SE reduction chs from the containing
        # block's mid_chs instead of in_chs
        if not attn_kwargs.pop('reduce_mid', False):
            attn_kwargs['reduced_base_chs'] = in_chs
        # if act_layer is not defined by the attn kwargs, the containing block's act_layer
        # is used for attn
        if attn_kwargs.get('act_layer', None) is None:
            assert act_layer is not None
            attn_kwargs['act_layer'] = act_layer
    return attn_kwargs


def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to the nearest multiple of ``divisor``.

    Ties round up. The result is never below ``min_value``, which defaults to ``divisor``.
    If rounding down would lose more than 10% of ``v``, one extra ``divisor`` is added.
    """
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier.

    A falsy ``multiplier`` (0 / None) returns ``channels`` unchanged. Otherwise the result
    is ``make_divisible(channels * multiplier, divisor, channel_min)``.
    """
    if not multiplier:
        return channels
    channels *= multiplier
    return make_divisible(channels, divisor, channel_min)


class ChannelShuffle(nn.Module):
    """Shuffle channels across ``groups`` (as in ShuffleNet). Expects a 4-D ``N, C, H, W`` input."""

    # FIXME haven't used yet
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
        N, C, H, W = x.size()
        g = self.groups
        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
            g, C
        )
        # integer division; no float round-trip for the per-group channel count
        return (
            x.view(N, g, C // g, H, W)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
            .view(N, C, H, W)
        )


class ConvBnAct(nn.Module):
    """Conv -> norm -> activation.

    Built with ``create_conv2d``, ``norm_layer`` and ``act_layer(inplace=True)``.
    """

    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(ConvBnAct, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

    def feature_module(self, location):
        # single feature point regardless of `location`
        return 'act1'

    def feature_channels(self, location):
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x


class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.

    Layout: depthwise conv -> norm -> act -> [attention] -> pointwise conv -> norm ->
    [act if ``pw_act``]. A residual connection (with optional drop-path) is added when
    ``stride == 1``, ``in_chs == out_chs`` and not ``noskip``.
    """
    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        self.conv_dw = create_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is not None:
            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
        else:
            self.se = None

        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

    def feature_module(self, location):
        # no expansion in this block, pre pw only feature extraction point
        return 'conv_pw'

    def feature_channels(self, location):
        return self.conv_pw.in_channels

    def forward(self, x):
        residual = x

        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.act1(x)

        if self.se is not None:
            x = self.se(x)

        x = self.conv_pw(x)
        x = self.bn2(x)
        x = self.act2(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual
        return x


class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE and CondConv routing

    Layout: pointwise expansion (to ``make_divisible(in_chs * exp_ratio)`` channels) ->
    norm -> act -> depthwise conv -> norm -> act -> [attention] -> pointwise-linear
    projection -> norm. ``conv_kwargs`` are forwarded to every ``create_conv2d`` call;
    ``CondConvResidual`` uses this to pass ``num_experts``.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_path_rate=0.):
        super(InvertedResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        conv_kwargs = conv_kwargs or {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is not None:
            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
        else:
            self.se = None

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        if location == 'post_exp':
            return 'act1'
        return 'conv_pwl'

    def feature_channels(self, location):
        if location == 'post_exp':
            return self.conv_pw.out_channels
        # location == 'pre_pw'
        return self.conv_pwl.in_channels

    def forward(self, x):
        residual = x

        # Point-wise expansion
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.act2(x)

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual

        return x


class XDepthwiseSeparableConv(nn.Module):
    """ Depthwise-separable block with a decomposed depthwise stage.

    Same layout as ``DepthwiseSeparableConv``, but its single k x k depthwise conv is
    replaced by a sequence of three depthwise convs: 2x2 (``'same'`` padding), then
    1 x ``dw_kernel_size``, then ``dw_kernel_size`` x 1. Unlike ``XInvertedResidual``, no
    ``pad_shift`` is forwarded to the 2x2 conv.

    NOTE(review): all three depthwise convs receive ``stride`` (and ``dilation``), so a
    stride > 1 downsamples three times in sequence. TODO confirm this is intended.
    """
    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
        super(XDepthwiseSeparableConv, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        # Decomposed depth-wise convolution: 2x2 -> 1xk -> kx1
        self.conv_dw_2x2 = create_conv2d(
            in_chs, in_chs, 2, stride=stride, dilation=dilation,
            padding='same', depthwise=True)
        self.conv_dw_1xk = create_conv2d(
            in_chs, in_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True)
        self.conv_dw_kx1 = create_conv2d(
            in_chs, in_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is not None:
            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
        else:
            self.se = None

        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

    def feature_module(self, location):
        # no expansion in this block, pre pw only feature extraction point
        return 'conv_pw'

    def feature_channels(self, location):
        return self.conv_pw.in_channels

    def forward(self, x):
        residual = x

        x = self.conv_dw_2x2(x)
        x = self.conv_dw_1xk(x)
        x = self.conv_dw_kx1(x)
        x = self.bn1(x)
        x = self.act1(x)

        if self.se is not None:
            x = self.se(x)

        x = self.conv_pw(x)
        x = self.bn2(x)
        x = self.act2(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual
        return x


class XInvertedResidual(nn.Module):
    """ Inverted residual block with a decomposed depthwise stage.

    Same layout as ``InvertedResidual``, but the depthwise conv is replaced by three
    depthwise convs: 2x2 (``'same'`` padding, with ``pad_shift`` forwarded), then
    1 x ``dw_kernel_size``, then ``dw_kernel_size`` x 1.

    NOTE(review): all three depthwise convs receive ``stride`` (and ``dilation``), so a
    stride > 1 downsamples three times in sequence. TODO confirm this is intended.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, pad_shift=0,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_path_rate=0.):
        super(XInvertedResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        conv_kwargs = conv_kwargs or {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution, decomposed into 2x2 -> 1xk -> kx1
        self.conv_dw_2x2 = create_conv2d(
            mid_chs, mid_chs, 2, stride=stride, dilation=dilation,
            padding='same', depthwise=True, pad_shift=pad_shift, **conv_kwargs)
        self.conv_dw_1xk = create_conv2d(
            mid_chs, mid_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.conv_dw_kx1 = create_conv2d(
            mid_chs, mid_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is not None:
            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
        else:
            self.se = None

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        if location == 'post_exp':
            return 'act1'
        return 'conv_pwl'

    def feature_channels(self, location):
        if location == 'post_exp':
            return self.conv_pw.out_channels
        # location == 'pre_pw'
        return self.conv_pwl.in_channels

    def forward(self, x):
        residual = x

        # Point-wise expansion
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw_2x2(x)
        x = self.conv_dw_1xk(x)
        x = self.conv_dw_kx1(x)
        x = self.bn2(x)
        x = self.act2(x)

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual

        return x


class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing

    ``num_experts`` is forwarded to every conv via ``conv_kwargs``. In ``forward``, routing
    weights are ``sigmoid(routing_fn(global_avg_pool(x)))``, one per expert, and are passed
    as the second argument to each of the three convs.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 num_experts=0, drop_path_rate=0.):
        conv_kwargs = dict(num_experts=num_experts)

        super(CondConvResidual, self).__init__(
            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
            act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size, attn_layer=attn_layer, attn_kwargs=attn_kwargs,
            norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
            drop_path_rate=drop_path_rate)

        # set after nn.Module.__init__ has run (the parent constructor does not read it)
        self.num_experts = num_experts
        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
        residual = x

        # CondConv routing
        pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
        routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))

        # Point-wise expansion
        x = self.conv_pw(x, routing_weights)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw(x, routing_weights)
        x = self.bn2(x)
        x = self.act2(x)

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x, routing_weights)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual
        return x


class EdgeResidual(nn.Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride

    Layout: ``exp_kernel_size`` expansion conv -> norm -> act -> [attention] ->
    pointwise-linear projection -> norm. ``stride`` and ``dilation`` are applied in the
    projection conv, not in the expansion conv. When ``fake_in_chs > 0`` the mid channel
    count is computed from ``fake_in_chs`` instead of ``in_chs``.
    """

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                 attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 drop_path_rate=0.):
        super(EdgeResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        if fake_in_chs > 0:
            mid_chs = make_divisible(fake_in_chs * exp_ratio)
        else:
            mid_chs = make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Expansion convolution
        self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Attention block (Squeeze-Excitation, ECA, etc)
        if attn_layer is not None:
            attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
            self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
        else:
            self.se = None

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(
            mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)

    def feature_module(self, location):
        if location == 'post_exp':
            return 'act1'
        return 'conv_pwl'

    def feature_channels(self, location):
        if location == 'post_exp':
            return self.conv_exp.out_channels
        # location == 'pre_pw'
        return self.conv_pwl.in_channels

    def forward(self, x):
        residual = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Attention
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn2(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual

        return x