diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 08bc1699..4ef966ea 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -8,6 +8,7 @@ from .xception import *
 from .nasnet import *
 from .pnasnet import *
 from .gen_efficientnet import *
+from .mobilenetv3 import *
 from .inception_v3 import *
 from .gluon_resnet import *
 from .gluon_xception import *
diff --git a/timm/models/activations.py b/timm/models/activations.py
index aa29b84d..aafa290c 100644
--- a/timm/models/activations.py
+++ b/timm/models/activations.py
@@ -7,72 +7,64 @@ _USE_MEM_EFFICIENT_ISH = True
 if _USE_MEM_EFFICIENT_ISH:
     # This version reduces memory overhead of Swish during training by
     # recomputing torch.sigmoid(x) in backward instead of saving it.
-    class SwishAutoFn(torch.autograd.Function):
-        """Swish - Described in: https://arxiv.org/abs/1710.05941
-        Memory efficient variant from:
-        https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
-        """
-        @staticmethod
-        def forward(ctx, x):
-            result = x.mul(torch.sigmoid(x))
-            ctx.save_for_backward(x)
-            return result
+    @torch.jit.script
+    def swish_jit_fwd(x):
+        return x.mul(torch.sigmoid(x))

-        @staticmethod
-        def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            sigmoid_x = torch.sigmoid(x)
-            return grad_output.mul(sigmoid_x * (1 + x * (1 - sigmoid_x)))

-    def swish(x, inplace=False):
-        # inplace ignored
-        return SwishAutoFn.apply(x)
+    @torch.jit.script
+    def swish_jit_bwd(x, grad_output):
+        # d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


-    class MishAutoFn(torch.autograd.Function):
-        """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-        Experimental memory-efficient variant
+    class SwishJitAutoFn(torch.autograd.Function):
+        """ torch.jit.script optimised Swish
+        Inspired by a conversation between Jeremy Howard & Adam Paszke
+        https://twitter.com/jeremyphoward/status/1188251041835315200
         """
         @staticmethod
         def forward(ctx, x):
             ctx.save_for_backward(x)
-            y = x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
-            return y
+            return swish_jit_fwd(x)

         @staticmethod
         def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            x_sigmoid = torch.sigmoid(x)
-            x_tanh_sp = F.softplus(x).tanh()
-            return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
+            x = ctx.saved_tensors[0]
+            return swish_jit_bwd(x, grad_output)

-    def mish(x, inplace=False):
-        # inplace ignored
-        return MishAutoFn.apply(x)
+    def swish(x, _inplace=False):
+        return SwishJitAutoFn.apply(x)

-    class WishAutoFn(torch.autograd.Function):
-        """Wish: My own mistaken creation while fiddling with Mish. Did well in some experiments.
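A quick standalone sanity check (an editor's illustration, not part of the patch) that the hand-written backward in swish_jit_bwd above matches autograd on the naive x * sigmoid(x) form:

import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)

# autograd reference on the naive formulation
g_ref, = torch.autograd.grad((x * torch.sigmoid(x)).sum(), x)

# the analytic gradient used by swish_jit_bwd (grad_output == 1 here)
s = torch.sigmoid(x.detach())
g_manual = s * (1 + x.detach() * (1 - s))

print(torch.allclose(g_ref, g_manual))  # True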
- Experimental memory-efficient variant - """ + @torch.jit.script + def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + + @torch.jit.script + def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + + class MishJitAutoFn(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) - y = x.mul(torch.tanh(torch.exp(x))) - return y + return mish_jit_fwd(x) @staticmethod def backward(ctx, grad_output): - x = ctx.saved_variables[0] - x_exp = x.exp() - x_tanh_exp = x_exp.tanh() - return grad_output.mul(x_tanh_exp + x * x_exp * (1 - x_tanh_exp * x_tanh_exp)) - - def wish(x, inplace=False): - # inplace ignored - return WishAutoFn.apply(x) + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + def mish(x, _inplace=False): + return MishJitAutoFn.apply(x) + else: def swish(x, inplace=False): """Swish - Described in: https://arxiv.org/abs/1710.05941 @@ -80,18 +72,10 @@ else: return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) - def mish(x, inplace=False): + def mish(x, _inplace=False): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ - inner = F.softplus(x).tanh() - return x.mul_(inner) if inplace else x.mul(inner) - - - def wish(x, inplace=False): - """Wish: My own mistaken creation while fiddling with Mish. Did well in some experiments. - """ - inner = x.exp().tanh() - return x.mul_(inner) if inplace else x.mul(inner) + return x.mul(F.softplus(x).tanh()) class Swish(nn.Module): @@ -112,15 +96,6 @@ class Mish(nn.Module): return mish(x, self.inplace) -class Wish(nn.Module): - def __init__(self, inplace=False): - super(Wish, self).__init__() - self.inplace = inplace - - def forward(self, x): - return wish(x, self.inplace) - - def sigmoid(x, inplace=False): return x.sigmoid_() if inplace else x.sigmoid() diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index ea72d07c..acd14fde 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -102,13 +102,14 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): class MixedConv2d(nn.Module): """ Mixed Grouped Convolution - Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + + NOTE: This does not currently work with torch.jit.script """ def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, mixed_dilated=False, depthwise=False, **kwargs): + stride=1, padding='', dilation=1, depthwise=False, **kwargs): super(MixedConv2d, self).__init__() kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] @@ -118,17 +119,13 @@ class MixedConv2d(nn.Module): self.in_channels = sum(in_splits) self.out_channels = sum(out_splits) for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): - d = dilation - # FIXME make compat with non-square kernel/dilations/strides - if stride == 1 and mixed_dilated: - d, k = (k - 1) // 2, 3 conv_groups = out_ch if depthwise else 1 # use add_module to keep key space clean self.add_module( str(idx), create_conv2d_pad( in_ch, out_ch, k, stride=stride, - padding=padding, dilation=d, groups=conv_groups, **kwargs) + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) ) self.splits = in_splits @@ -154,12 +151,12 @@ def get_condconv_initializer(initializer, 
num_experts, expert_shape): class CondConv2d(nn.Module): """ Conditional Convolution - Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: https://github.com/pytorch/pytorch/issues/17983 """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): @@ -171,13 +168,10 @@ class CondConv2d(nn.Module): self.stride = _pair(stride) padding_val, is_padding_dynamic = get_padding_value( padding, kernel_size, stride=stride, dilation=dilation) - self.conv_fn = conv2d_same if is_padding_dynamic else F.conv2d + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript self.padding = _pair(padding_val) self.dilation = _pair(dilation) - self.transposed = False - self.output_padding = _pair(0) self.groups = groups - self.padding_mode = 'zero' self.num_experts = num_experts self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size @@ -186,24 +180,19 @@ class CondConv2d(nn.Module): weight_num_param *= wd self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) - # FIXME I haven't tested bias yet if bias: self.bias_shape = (self.out_channels,) - condconv_bias_shape = (self.num_experts, self.out_channels) - self.bias = torch.nn.Parameter(torch.Tensor(condconv_bias_shape)) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() - # FIXME once I'm satisfied this works, remove the looping path? 
- self._use_groups = True # use groups for parallel per-batch-element kernel convolution def reset_parameters(self): init_weight = get_condconv_initializer( partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) init_weight(self.weight) if self.bias is not None: - # FIXME bias not tested fan_in = np.prod(self.weight_shape[1:]) bound = 1 / math.sqrt(fan_in) init_bias = get_condconv_initializer( @@ -211,35 +200,43 @@ class CondConv2d(nn.Module): init_bias(self.bias) def forward(self, x, routing_weights): - weight = torch.matmul(routing_weights, self.weight) - bias = torch.matmul(routing_weights, self.bias) if self.bias is not None else None B, C, H, W = x.shape - if self._use_groups: - new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size - weight = weight.view(new_weight_shape) - # move batch elements with channels so each batch element can be efficiently convolved with separate kernel - x = x.view(1, B * C, H, W) - out = self.conv_fn( + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups * B) - out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) else: - x = torch.split(x, 1, 0) - weight = torch.split(weight, 1, 0) - if self.bias is not None: - bias = torch.matmul(routing_weights, self.bias) - bias = torch.split(bias, 1, 0) - else: - bias = [None] * B - out = [] - for xi, wi, bi in zip(x, weight, bias): - wi = wi.view(*self.weight_shape) - if bi is not None: - bi = bi.view(*self.bias_shape) - out.append(self.conv_fn( - xi, wi, bi, stride=self.stride, padding=self.padding, - dilation=self.dilation, groups=self.groups)) - out = torch.cat(out, 0) + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) return out @@ -250,13 +247,14 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently # We're going to use only lists for defining the MixedConv2d kernel groups, # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
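The grouped-convolution trick used by CondConv2d.forward above is easy to verify in isolation. A minimal standalone sketch (an editor's illustration, not part of the patch) showing that folding the batch into the channel dimension with groups=B applies a separate kernel per sample:

import torch
import torch.nn.functional as F

B, C_in, C_out, K = 4, 8, 16, 3
x = torch.randn(B, C_in, 14, 14)
w = torch.randn(B, C_out, C_in, K, K)  # one kernel per batch element

# looped reference: convolve each sample with its own kernel
ref = torch.cat([F.conv2d(x[i:i + 1], w[i], padding=1) for i in range(B)], 0)

# single call: fold batch into channels, groups=B keeps the kernels per-sample
out = F.conv2d(x.reshape(1, B * C_in, 14, 14),
               w.reshape(B * C_out, C_in, K, K), padding=1, groups=B)
out = out.reshape(B, C_out, 14, 14)
print(torch.allclose(ref, out, atol=1e-5))  # True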
-        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
+        m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
     else:
         depthwise = kwargs.pop('depthwise', False)
         groups = out_chs if depthwise else 1
         if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
-            create_fn = CondConv2d
+            m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
         else:
-            create_fn = create_conv2d_pad
-    return create_fn(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+            m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+    return m
+
diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py
new file mode 100644
index 00000000..13ab051a
--- /dev/null
+++ b/timm/models/efficientnet_blocks.py
@@ -0,0 +1,404 @@
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .activations import sigmoid
+from .conv2d_layers import *
+
+
+# Defaults used for Google/Tensorflow training of mobile networks w/ RMSprop as per
+# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
+# NOTE: momentum varies between .99 and .9997 depending on source
+#   .99 in official TF TPU impl
+#   .9997 (w/ .999 in search space) for paper
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+
+def get_bn_args_tf():
+    return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
+    bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
+    bn_eps = kwargs.pop('bn_eps', None)
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args
+
+
+_SE_ARGS_DEFAULT = dict(
+    gate_fn=sigmoid,
+    act_layer=None,
+    reduce_mid=False,
+    divisor=1)
+
+
+def resolve_se_args(kwargs, in_chs, act_layer=None):
+    se_kwargs = kwargs.copy() if kwargs is not None else {}
+    # fill in args that aren't specified with the defaults
+    for k, v in _SE_ARGS_DEFAULT.items():
+        se_kwargs.setdefault(k, v)
+    # some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
+    if not se_kwargs.pop('reduce_mid'):
+        se_kwargs['reduced_base_chs'] = in_chs
+    # act_layer override, if it remains None, the containing block's act_layer will be used
+    if se_kwargs['act_layer'] is None:
+        assert act_layer is not None
+        se_kwargs['act_layer'] = act_layer
+    return se_kwargs
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    min_value = min_value or divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
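+    # e.g. make_divisible(38.4) -> 40, but make_divisible(11) -> 16 rather than 8,
+    # since rounding 11 down to 8 would lose more than 10% of the channels.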
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +def drop_connect(inputs, training=False, drop_connect_rate=0.): + """Apply drop connect.""" + if not training: + return inputs + + keep_prob = 1 - drop_connect_rate + random_tensor = keep_prob + torch.rand( + (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) + random_tensor.floor_() # binarize + output = inputs.div(keep_prob) * random_tensor + return output + + +class ChannelShuffle(nn.Module): + # FIXME haven't used yet + def __init__(self, groups): + super(ChannelShuffle, self).__init__() + self.groups = groups + + def forward(self, x): + """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" + N, C, H, W = x.size() + g = self.groups + assert C % g == 0, "Incompatible group size {} for input channel {}".format( + g, C + ) + return ( + x.view(N, g, int(C / g), H, W) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(N, C, H, W) + ) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, + act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + norm_kwargs = norm_kwargs or {} + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def feature_module(self, location): + return 'act1' + + def feature_channels(self, location): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion + (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + norm_kwargs = norm_kwargs or {} + self.has_se = se_ratio is not None and se_ratio > 0. 
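+        # identity shortcut is only valid when the block preserves spatial dims and channel count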
+ self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act # activation after point-wise conv + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if self.has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() + + def feature_module(self, location): + # no expansion in this block, pre pw only feature extraction point + return 'conv_pw' + + def feature_channels(self, location): + return self.conv_pw.in_channels + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + if self.has_se: + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE and CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + self.has_se = se_ratio is not None and se_ratio > 0. 
+ self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if self.has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def feature_module(self, location): + if location == 'post_exp': + return 'act1' + return 'conv_pwl' + + def feature_channels(self, location): + if location == 'post_exp': + return self.conv_pw.out_channels + # location == 'pre_pw' + return self.conv_pwl.in_channels + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.has_se: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_connect_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_connect_rate=drop_connect_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.has_se: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ Residual block with expansion convolution followed by pointwise-linear w/ stride""" + + def __init__(self, in_chs, 
out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
+                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
+                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+                 drop_connect_rate=0.):
+        super(EdgeResidual, self).__init__()
+        norm_kwargs = norm_kwargs or {}
+        if fake_in_chs > 0:
+            mid_chs = make_divisible(fake_in_chs * exp_ratio)
+        else:
+            mid_chs = make_divisible(in_chs * exp_ratio)
+        self.has_se = se_ratio is not None and se_ratio > 0.
+        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
+        self.drop_connect_rate = drop_connect_rate
+
+        # Expansion convolution
+        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
+        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+        self.act1 = act_layer(inplace=True)
+
+        # Squeeze-and-excitation
+        if self.has_se:
+            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
+            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+
+        # Point-wise linear projection
+        self.conv_pwl = select_conv2d(
+            mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
+        self.bn2 = norm_layer(out_chs, **norm_kwargs)
+
+    def feature_module(self, location):
+        if location == 'post_exp':
+            return 'act1'
+        return 'conv_pwl'
+
+    def feature_channels(self, location):
+        if location == 'post_exp':
+            return self.conv_exp.out_channels
+        # location == 'pre_pw'
+        return self.conv_pwl.in_channels
+
+    def forward(self, x):
+        residual = x
+
+        # Expansion convolution
+        x = self.conv_exp(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+
+        # Squeeze-and-excitation
+        if self.has_se:
+            x = self.se(x)
+
+        # Point-wise linear projection
+        x = self.conv_pwl(x)
+        x = self.bn2(x)
+
+        if self.has_residual:
+            if self.drop_connect_rate > 0.:
+                x = drop_connect(x, self.training, self.drop_connect_rate)
+            x += residual
+
+        return x
diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py
new file mode 100644
index 00000000..c2b3a801
--- /dev/null
+++ b/timm/models/efficientnet_builder.py
@@ -0,0 +1,402 @@
+import logging
+import math
+import re
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch.nn as nn
+from .activations import sigmoid, HardSwish, Swish
+from .efficientnet_blocks import *
+
+
+def _parse_ksize(ss):
+    if ss.isdigit():
+        return int(ss)
+    else:
+        return [int(k) for k in ss.split('.')]
+
+
+def _decode_block_str(block_str):
+    """ Decode block definition string
+
+    Gets a list of block arg (dicts) through a string notation of arguments.
+    E.g. ir_r2_k3_s2_e1_c16_se0.25_noskip
+
+    All args can exist in any order with the exception of the leading string which
+    is assumed to indicate the block type.
+
+    leading string - block type (
+      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
+    r - number of repeat blocks,
+    k - kernel size,
+    s - strides (1-9),
+    e - expansion ratio,
+    c - output channels,
+    se - squeeze/excitation ratio
+    n - activation fn ('re', 'r6', 'hs', or 'sw')
+    Args:
+        block_str: a string representation of block arguments.
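+    For example, 'ir_r2_k3_s2_e1_c16_se0.25_noskip' (above) decodes to num_repeat=2
+    and block_args of:
+        dict(block_type='ir', dw_kernel_size=3, exp_kernel_size=1, pw_kernel_size=1,
+             out_chs=16, exp_ratio=1.0, se_ratio=0.25, stride=2, act_layer=None, noskip=True)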
+    Returns:
+        A list of block args (dicts)
+    Raises:
+        ValueError: if the string def is not properly specified (TODO)
+    """
+    assert isinstance(block_str, str)
+    ops = block_str.split('_')
+    block_type = ops[0]  # take the block type off the front
+    ops = ops[1:]
+    options = {}
+    noskip = False
+    for op in ops:
+        # string options being checked on individual basis, combine if they grow
+        if op == 'noskip':
+            noskip = True
+        elif op.startswith('n'):
+            # activation fn
+            key = op[0]
+            v = op[1:]
+            if v == 're':
+                value = nn.ReLU
+            elif v == 'r6':
+                value = nn.ReLU6
+            elif v == 'hs':
+                value = HardSwish
+            elif v == 'sw':
+                value = Swish
+            else:
+                continue
+            options[key] = value
+        else:
+            # all numeric options
+            splits = re.split(r'(\d.*)', op)
+            if len(splits) >= 2:
+                key, value = splits[:2]
+                options[key] = value
+
+    # if act_layer is None, the model default (passed to model init) will be used
+    act_layer = options['n'] if 'n' in options else None
+    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
+
+    num_repeat = int(options['r'])
+    # each type of block has different valid arguments, fill accordingly
+    if block_type == 'ir':
+        block_args = dict(
+            block_type=block_type,
+            dw_kernel_size=_parse_ksize(options['k']),
+            exp_kernel_size=exp_kernel_size,
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            exp_ratio=float(options['e']),
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_layer=act_layer,
+            noskip=noskip,
+        )
+        if 'cc' in options:
+            block_args['num_experts'] = int(options['cc'])
+    elif block_type == 'ds' or block_type == 'dsa':
+        block_args = dict(
+            block_type=block_type,
+            dw_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_layer=act_layer,
+            pw_act=block_type == 'dsa',
+            noskip=block_type == 'dsa' or noskip,
+        )
+    elif block_type == 'er':
+        block_args = dict(
+            block_type=block_type,
+            exp_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            exp_ratio=float(options['e']),
+            fake_in_chs=fake_in_chs,
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_layer=act_layer,
+            noskip=noskip,
+        )
+    elif block_type == 'cn':
+        block_args = dict(
+            block_type=block_type,
+            kernel_size=int(options['k']),
+            out_chs=int(options['c']),
+            stride=int(options['s']),
+            act_layer=act_layer,
+        )
+    else:
+        assert False, 'Unknown block type (%s)' % block_type
+
+    return block_args, num_repeat
+
+
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+    """ Per-stage depth scaling
+    Scales the block repeats in each stage. This depth scaling impl maintains
+    compatibility with the EfficientNet scaling method, while allowing sensible
+    scaling for other models that may have multiple block arg definitions in each stage.
+    """
+
+    # We scale the total repeat count for each stage; there may be multiple
+    # block arg defs per stage so we need to sum.
+    num_repeat = sum(repeats)
+    if depth_trunc == 'round':
+        # Truncating to int by rounding allows stages with few repeats to remain
+        # proportionally smaller for longer.
This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +class EfficientNetBuilder: + """ Build Trunk Blocks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0., feature_location='', + verbose=False): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.output_stride = output_stride + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_connect_rate = drop_connect_rate + self.feature_location = feature_location + assert feature_location in ('pre_pwl', 'post_exp', '') + self.verbose = verbose + + # state updated during build, consumed by model + self.in_chs = None + self.features = OrderedDict() + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba, block_idx, block_count): + drop_connect_rate = self.drop_connect_rate * block_idx / block_count + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + 
+        ba['pad_type'] = self.pad_type
+        # block act fn overrides the model default
+        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
+        assert ba['act_layer'] is not None
+        if bt == 'ir':
+            ba['drop_connect_rate'] = drop_connect_rate
+            ba['se_kwargs'] = self.se_kwargs
+            if self.verbose:
+                logging.info('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
+            if ba.get('num_experts', 0) > 0:
+                block = CondConvResidual(**ba)
+            else:
+                block = InvertedResidual(**ba)
+        elif bt == 'ds' or bt == 'dsa':
+            ba['drop_connect_rate'] = drop_connect_rate
+            ba['se_kwargs'] = self.se_kwargs
+            if self.verbose:
+                logging.info('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
+            block = DepthwiseSeparableConv(**ba)
+        elif bt == 'er':
+            ba['drop_connect_rate'] = drop_connect_rate
+            ba['se_kwargs'] = self.se_kwargs
+            if self.verbose:
+                logging.info('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
+            block = EdgeResidual(**ba)
+        elif bt == 'cn':
+            if self.verbose:
+                logging.info('  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
+            block = ConvBnAct(**ba)
+        else:
+            assert False, 'Unknown block type (%s) while building model.' % bt
+        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
+
+        return block
+
+    def __call__(self, in_chs, model_block_args):
+        """ Build the blocks
+        Args:
+            in_chs: Number of input-channels passed to first block
+            model_block_args: A list of lists, outer list defines stages, inner
+                list contains strings defining block configuration(s)
+        Return:
+             List of block stacks (each stack wrapped in nn.Sequential)
+        """
+        if self.verbose:
+            logging.info('Building model trunk with %d stages...' % len(model_block_args))
+        self.in_chs = in_chs
+        total_block_count = sum([len(x) for x in model_block_args])
+        total_block_idx = 0
+        current_stride = 2
+        current_dilation = 1
+        feature_idx = 0
+        stages = []
+        # outer list of block_args defines the stacks ('stages' by some conventions)
+        for stage_idx, stage_block_args in enumerate(model_block_args):
+            last_stack = stage_idx == (len(model_block_args) - 1)
+            if self.verbose:
+                logging.info('Stack: {}'.format(stage_idx))
+            assert isinstance(stage_block_args, list)
+
+            blocks = []
+            # each stack (stage) contains a list of block arguments
+            for block_idx, block_args in enumerate(stage_block_args):
+                last_block = block_idx == (len(stage_block_args) - 1)
+                extract_features = ''  # No features extracted
+                if self.verbose:
+                    logging.info(' Block: {}'.format(block_idx))
+
+                # Sort out stride, dilation, and feature extraction details
+                assert block_args['stride'] in (1, 2)
+                if block_idx >= 1:
+                    # only the first block in any stack can have a stride > 1
+                    block_args['stride'] = 1
+
+                do_extract = False
+                if self.feature_location == 'pre_pwl':
+                    if last_block:
+                        next_stage_idx = stage_idx + 1
+                        if next_stage_idx >= len(model_block_args):
+                            do_extract = True
+                        else:
+                            do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
+                elif self.feature_location == 'post_exp':
+                    if block_args['stride'] > 1 or (last_stack and last_block):
+                        do_extract = True
+                if do_extract:
+                    extract_features = self.feature_location
+
+                next_dilation = current_dilation
+                if block_args['stride'] > 1:
+                    next_output_stride = current_stride * block_args['stride']
+                    if next_output_stride > self.output_stride:
+                        next_dilation = current_dilation * block_args['stride']
+                        block_args['stride'] = 1
+                        if self.verbose:
+                            logging.info('  Converting stride to dilation to maintain output_stride=={}'.format(
+                                self.output_stride))
+                    else:
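+                        # still within the target output_stride; keep the stride and
+                        # advance the running stride product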
+ current_stride = next_output_stride + block_args['dilation'] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + + # create the block + block = self._make_block(block_args, total_block_idx, total_block_count) + blocks.append(block) + + # stash feature module name and channel info for model feature extraction + if extract_features: + feature_module = block.feature_module(extract_features) + if feature_module: + feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module + feature_channels = block.feature_channels(extract_features) + self.features[feature_idx] = dict( + name=feature_module, + num_chs=feature_channels + ) + feature_idx += 1 + + total_block_idx += 1 # incr global block idx (across all stacks) + stages.append(nn.Sequential(*blocks)) + return stages + + +def efficientnet_init_goog(m, n=''): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def efficientnet_init_default(m, n=''): + if isinstance(m, CondConv2d): + init_fn = get_condconv_initializer(partial( + nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) + init_fn(m.weight) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') + + diff --git a/timm/models/feature_hooks.py b/timm/models/feature_hooks.py new file mode 100644 index 00000000..8ffcda86 --- /dev/null +++ b/timm/models/feature_hooks.py @@ -0,0 +1,31 @@ +from collections import defaultdict, OrderedDict +from functools import partial + + +class FeatureHooks: + + def __init__(self, hooks, named_modules): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for h in hooks: + hook_name = h['name'] + m = modules[hook_name] + hook_fn = partial(self._collect_output_hook, hook_name) + if h['type'] == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif h['type'] == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, name, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][name] = x + + def get_output(self, device): + output = tuple(self._feature_outputs[device].values())[::-1] + self._feature_outputs[device] = 
OrderedDict() # clear after reading + return output diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index c3b1b0e2..fe20ff13 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -7,8 +7,7 @@ A generic class with building blocks to support a variety of models with efficie * MixNet (Small, Medium, and Large) * MnasNet B1, A1 (SE), Small * MobileNet V1, V2, and V3 -* FBNet-C (TODO A & B) -* ChamNet (TODO still guessing at architecture definition) +* FBNet-C * Single-Path NAS Pixel1 * And likely more... @@ -16,28 +15,16 @@ TODO not all combinations and variations have been tested. Currently working on Hacked together by Ross Wightman """ - -import math -import re -import logging -from copy import deepcopy -from functools import partial -from collections import OrderedDict, defaultdict - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from timm.models.activations import Swish, sigmoid, HardSwish, hard_sigmoid -from .registry import register_model, model_entrypoint +from .efficientnet_builder import * +from .feature_hooks import FeatureHooks +from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from .conv2d_layers import select_conv2d -from .layers import Flatten from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -__all__ = ['GenEfficientNet'] +__all__ = ['EfficientNet'] def _cfg(url='', **kwargs): @@ -62,14 +49,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), 'semnasnet_140': _cfg(url=''), 'mnasnet_small': _cfg(url=''), - 'mobilenetv1_100': _cfg(url=''), 'mobilenetv2_100': _cfg(url=''), - 'mobilenetv3_050': _cfg(url=''), - 'mobilenetv3_075': _cfg(url=''), - 'mobilenetv3_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth'), - 'chamnetv1_100': _cfg(url=''), - 'chamnetv2_100': _cfg(url=''), 'fbnetc_100': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', interpolation='bilinear'), @@ -94,14 +74,14 @@ default_cfgs = { url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), 'efficientnet_b7': _cfg( url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'efficientnet_b8': _cfg( + url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), 'efficientnet_es': _cfg( url=''), 'efficientnet_em': _cfg( - url='', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), 'efficientnet_el': _cfg( - url='', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), 'efficientnet_cc_b0_4e': _cfg(url=''), 'efficientnet_cc_b0_8e': _cfg(url=''), 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), @@ -129,6 +109,41 @@ default_cfgs = { 'tf_efficientnet_b7': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b0_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + 
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), 'tf_efficientnet_es': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), @@ -169,896 +184,72 @@ default_cfgs = { } -_DEBUG = True - -# Default args for PyTorch BN impl -_BN_MOMENTUM_PT_DEFAULT = 0.1 -_BN_EPS_PT_DEFAULT = 1e-5 -_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT) - -# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per -# papers and TF reference implementations. 
PT momentum equiv for TF decay is (1 - TF decay) -# NOTE: momentum varies btw .99 and .9997 depending on source -# .99 in official TF TPU impl -# .9997 (/w .999 in search space) for paper -_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 -_BN_EPS_TF_DEFAULT = 1e-3 -_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT) - - -def _resolve_bn_args(kwargs): - bn_args = _BN_ARGS_TF.copy() if kwargs.pop('bn_tf', False) else _BN_ARGS_PT.copy() - bn_momentum = kwargs.pop('bn_momentum', None) - if bn_momentum is not None: - bn_args['momentum'] = bn_momentum - bn_eps = kwargs.pop('bn_eps', None) - if bn_eps is not None: - bn_args['eps'] = bn_eps - return bn_args - - -def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): - """Round number of filters based on depth multiplier.""" - if not multiplier: - return channels - - channels *= multiplier - channel_min = channel_min or divisor - new_channels = max( - int(channels + divisor / 2) // divisor * divisor, - channel_min) - # Make sure that round down does not go down by more than 10%. - if new_channels < 0.9 * channels: - new_channels += divisor - return new_channels - - -def _parse_ksize(ss): - if ss.isdigit(): - return int(ss) - else: - return [int(k) for k in ss.split('.')] - - -def _decode_block_str(block_str): - """ Decode block definition string - - Gets a list of block arg (dicts) through a string notation of arguments. - E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip - - All args can exist in any order with the exception of the leading string which - is assumed to indicate the block type. - - leading string - block type ( - ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) - r - number of repeat blocks, - k - kernel size, - s - strides (1-9), - e - expansion ratio, - c - output channels, - se - squeeze/excitation ratio - n - activation fn ('re', 'r6', 'hs', or 'sw') - Args: - block_str: a string representation of block arguments. 
- Returns: - A list of block args (dicts) - Raises: - ValueError: if the string def not properly specified (TODO) - """ - assert isinstance(block_str, str) - ops = block_str.split('_') - block_type = ops[0] # take the block type off the front - ops = ops[1:] - options = {} - noskip = False - for op in ops: - # string options being checked on individual basis, combine if they grow - if op == 'noskip': - noskip = True - elif op.startswith('n'): - # activation fn - key = op[0] - v = op[1:] - if v == 're': - value = nn.ReLU - elif v == 'r6': - value = nn.ReLU6 - elif v == 'hs': - value = HardSwish - elif v == 'sw': - value = Swish - else: - continue - options[key] = value - else: - # all numeric options - splits = re.split(r'(\d.*)', op) - if len(splits) >= 2: - key, value = splits[:2] - options[key] = value - - # if act_layer is None, the model default (passed to model init) will be used - act_layer = options['n'] if 'n' in options else None - exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 - pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 - fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def - - num_repeat = int(options['r']) - # each type of block has different valid arguments, fill accordingly - if block_type == 'ir': - block_args = dict( - block_type=block_type, - dw_kernel_size=_parse_ksize(options['k']), - exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), - exp_ratio=float(options['e']), - se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s']), - act_layer=act_layer, - noskip=noskip, - num_experts=int(options['cc']) if 'cc' in options else 0 - ) - elif block_type == 'ds' or block_type == 'dsa': - block_args = dict( - block_type=block_type, - dw_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), - se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s']), - act_layer=act_layer, - pw_act=block_type == 'dsa', - noskip=block_type == 'dsa' or noskip, - ) - elif block_type == 'er': - block_args = dict( - block_type=block_type, - exp_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), - exp_ratio=float(options['e']), - fake_in_chs=fake_in_chs, - se_ratio=float(options['se']) if 'se' in options else None, - stride=int(options['s']), - act_layer=act_layer, - noskip=noskip, - ) - elif block_type == 'cn': - block_args = dict( - block_type=block_type, - kernel_size=int(options['k']), - out_chs=int(options['c']), - stride=int(options['s']), - act_layer=act_layer, - ) - else: - assert False, 'Unknown block type (%s)' % block_type - - return block_args, num_repeat - - -def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): - """ Per-stage depth scaling - Scales the block repeats in each stage. This depth scaling impl maintains - compatibility with the EfficientNet scaling method, while allowing sensible - scaling for other models that may have multiple block arg definitions in each stage. - """ - - # We scale the total repeat count for each stage, there may be multiple - # block arg defs per stage so we need to sum. - num_repeat = sum(repeats) - if depth_trunc == 'round': - # Truncating to int by rounding allows stages with few repeats to remain - # proportionally smaller for longer. 
This is a good choice when stage definitions - # include single repeat stages that we'd prefer to keep that way as long as possible - num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) - else: - # The default for EfficientNet truncates repeats to int via 'ceil'. - # Any multiplier > 1.0 will result in an increased depth for every stage. - num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) - - # Proportionally distribute repeat count scaling to each block definition in the stage. - # Allocation is done in reverse as it results in the first block being less likely to be scaled. - # The first block makes less sense to repeat in most of the arch definitions. - repeats_scaled = [] - for r in repeats[::-1]: - rs = max(1, round((r / num_repeat * num_repeat_scaled))) - repeats_scaled.append(rs) - num_repeat -= r - num_repeat_scaled -= rs - repeats_scaled = repeats_scaled[::-1] - - # Apply the calculated scaling to each block arg in the stage - sa_scaled = [] - for ba, rep in zip(stack_args, repeats_scaled): - sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) - return sa_scaled - - -def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1): - arch_args = [] - for stack_idx, block_strings in enumerate(arch_def): - assert isinstance(block_strings, list) - stack_args = [] - repeats = [] - for block_str in block_strings: - assert isinstance(block_str, str) - ba, rep = _decode_block_str(block_str) - if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: - ba['num_experts'] *= experts_multiplier - stack_args.append(ba) - repeats.append(rep) - arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) - return arch_args - - -_USE_SWISH_OPT = True -if _USE_SWISH_OPT: - @torch.jit.script - def swish_jit_fwd(x): - return x.mul(torch.sigmoid(x)) - - - @torch.jit.script - def swish_jit_bwd(x, grad_output): - x_sigmoid = torch.sigmoid(x) - return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) - - - class SwishJitAutoFn(torch.autograd.Function): - """ torch.jit.script optimised Swish - Inspired by conversation btw Jeremy Howard & Adam Pazske - https://twitter.com/jeremyphoward/status/1188251041835315200 - """ - - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return swish_jit_fwd(x) - - @staticmethod - def backward(ctx, grad_output): - x = ctx.saved_tensors[0] - return swish_jit_bwd(x, grad_output) - - - def swish(x, inplace=False): - # inplace ignored - return SwishJitAutoFn.apply(x) -else: - def swish(x, inplace=False): - return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) - - -def sigmoid(x, inplace=False): - return x.sigmoid_() if inplace else x.sigmoid() - - -def hard_swish(x, inplace=False): - if inplace: - return x.mul_(F.relu6(x + 3.) / 6.) - else: - return x * F.relu6(x + 3.) / 6. - - -def hard_sigmoid(x, inplace=False): - if inplace: - return x.add_(3.).clamp_(0., 6.).div_(6.) - else: - return F.relu6(x + 3.) / 6. 
- - -class _BlockBuilder: - """ Build Trunk Blocks - - This ended up being somewhat of a cross between - https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py - and - https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py - - """ - def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - output_stride=32, pad_type='', act_layer=None, se_gate_fn=sigmoid, se_reduce_mid=False, - norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0., feature_location='', - verbose=False): - self.channel_multiplier = channel_multiplier - self.channel_divisor = channel_divisor - self.channel_min = channel_min - self.output_stride = output_stride - self.pad_type = pad_type - self.act_layer = act_layer - self.se_gate_fn = se_gate_fn - self.se_reduce_mid = se_reduce_mid - self.norm_layer = norm_layer - self.norm_kwargs = norm_kwargs - self.drop_connect_rate = drop_connect_rate - self.feature_location = feature_location - assert feature_location in ('pre_pwl', 'post_exp', '') - self.verbose = verbose - - # state updated during build, consumed by model - self.in_chs = None - self.features = OrderedDict() - - def _round_channels(self, chs): - return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) - - def _make_block(self, ba, block_idx, block_count): - drop_connect_rate = self.drop_connect_rate * block_idx / block_count - bt = ba.pop('block_type') - ba['in_chs'] = self.in_chs - ba['out_chs'] = self._round_channels(ba['out_chs']) - if 'fake_in_chs' in ba and ba['fake_in_chs']: - # FIXME this is a hack to work around mismatch in origin impl input filters - ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) - ba['norm_layer'] = self.norm_layer - ba['norm_kwargs'] = self.norm_kwargs - ba['pad_type'] = self.pad_type - # block act fn overrides the model default - ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer - assert ba['act_layer'] is not None - if bt == 'ir': - ba['drop_connect_rate'] = drop_connect_rate - ba['se_gate_fn'] = self.se_gate_fn - ba['se_reduce_mid'] = self.se_reduce_mid - if self.verbose: - logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba))) - block = InvertedResidual(**ba) - elif bt == 'ds' or bt == 'dsa': - ba['drop_connect_rate'] = drop_connect_rate - if self.verbose: - logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba))) - block = DepthwiseSeparableConv(**ba) - elif bt == 'er': - ba['drop_connect_rate'] = drop_connect_rate - ba['se_gate_fn'] = self.se_gate_fn - ba['se_reduce_mid'] = self.se_reduce_mid - if self.verbose: - logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba))) - block = EdgeResidual(**ba) - elif bt == 'cn': - if self.verbose: - logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba))) - block = ConvBnAct(**ba) - else: - assert False, 'Uknkown block type (%s) while building model.' % bt - self.in_chs = ba['out_chs'] # update in_chs for arg of next block - - return block - - def __call__(self, in_chs, model_block_args): - """ Build the blocks - Args: - in_chs: Number of input-channels passed to first block - model_block_args: A list of lists, outer list defines stages, inner - list contains strings defining block configuration(s) - Return: - List of block stacks (each stack wrapped in nn.Sequential) - """ - if self.verbose: - logging.info('Building model trunk with %d stages...' 
% len(model_block_args)) - self.in_chs = in_chs - total_block_count = sum([len(x) for x in model_block_args]) - total_block_idx = 0 - current_stride = 2 - current_dilation = 1 - feature_idx = 0 - stages = [] - # outer list of block_args defines the stacks ('stages' by some conventions) - for stage_idx, stage_block_args in enumerate(model_block_args): - last_stack = stage_idx == (len(model_block_args) - 1) - if self.verbose: - logging.info('Stack: {}'.format(stage_idx)) - assert isinstance(stage_block_args, list) - - blocks = [] - # each stack (stage) contains a list of block arguments - for block_idx, block_args in enumerate(stage_block_args): - last_block = block_idx == (len(stage_block_args) - 1) - extract_features = '' # No features extracted - if self.verbose: - logging.info(' Block: {}'.format(block_idx)) - - # Sort out stride, dilation, and feature extraction details - assert block_args['stride'] in (1, 2) - if block_idx >= 1: - # only the first block in any stack can have a stride > 1 - block_args['stride'] = 1 - - do_extract = False - if self.feature_location == 'pre_pwl': - if last_block: - next_stage_idx = stage_idx + 1 - if next_stage_idx >= len(model_block_args): - do_extract = True - else: - do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 - elif self.feature_location == 'post_exp': - if block_args['stride'] > 1 or (last_stack and last_block) : - do_extract = True - if do_extract: - extract_features = self.feature_location - - next_dilation = current_dilation - if block_args['stride'] > 1: - next_output_stride = current_stride * block_args['stride'] - if next_output_stride > self.output_stride: - next_dilation = current_dilation * block_args['stride'] - block_args['stride'] = 1 - if self.verbose: - logging.info(' Converting stride to dilation to maintain output_stride=={}'.format( - self.output_stride)) - else: - current_stride = next_output_stride - block_args['dilation'] = current_dilation - if next_dilation != current_dilation: - current_dilation = next_dilation - - # create the block - block = self._make_block(block_args, total_block_idx, total_block_count) - blocks.append(block) - - # stash feature module name and channel info for model feature extraction - if extract_features: - feature_module = block.feature_module(extract_features) - if feature_module: - feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module - feature_channels = block.feature_channels(extract_features) - self.features[feature_idx] = dict( - name=feature_module, - num_chs=feature_channels - ) - feature_idx += 1 - - total_block_idx += 1 # incr global block idx (across all stacks) - stages.append(nn.Sequential(*blocks)) - return stages - - -def _init_weight_goog(m): - # weight init as per Tensorflow Official impl - # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out - m.weight.data.normal_(0, math.sqrt(2.0 / n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(0) # fan-out - init_range = 1.0 / math.sqrt(n) - m.weight.data.uniform_(-init_range, init_range) - m.bias.data.zero_() - - -def _init_weight_default(m): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() 
- elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') - - -def drop_connect(inputs, training=False, drop_connect_rate=0.): - """Apply drop connect.""" - if not training: - return inputs - - keep_prob = 1 - drop_connect_rate - random_tensor = keep_prob + torch.rand( - (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) - random_tensor.floor_() # binarize - output = inputs.div(keep_prob) * random_tensor - return output - - -class ChannelShuffle(nn.Module): - # FIXME haven't used yet - def __init__(self, groups): - super(ChannelShuffle, self).__init__() - self.groups = groups - - def forward(self, x): - """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" - N, C, H, W = x.size() - g = self.groups - assert C % g == 0, "Incompatible group size {} for input channel {}".format( - g, C - ) - return ( - x.view(N, g, int(C / g), H, W) - .permute(0, 2, 1, 3, 4) - .contiguous() - .view(N, C, H, W) - ) - - -class SqueezeExcite(nn.Module): - def __init__(self, in_chs, reduce_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid): - super(SqueezeExcite, self).__init__() - self.gate_fn = gate_fn - reduced_chs = reduce_chs or in_chs - self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) - self.act1 = act_layer(inplace=True) - self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) - - def forward(self, x): - # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance - x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) - x_se = self.conv_reduce(x_se) - x_se = self.act1(x_se) - x_se = self.conv_expand(x_se) - x = x * self.gate_fn(x_se) - return x - -class ConvBnAct(nn.Module): - def __init__(self, in_chs, out_chs, kernel_size, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT,): - super(ConvBnAct, self).__init__() - assert stride in [1, 2] - self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) - self.bn1 = norm_layer(out_chs, **norm_kwargs) - self.act1 = act_layer(inplace=True) - - def feature_module(self, location): - return 'act1' +_DEBUG = False - def feature_channels(self, location): - return self.conv.out_channels - def forward(self, x): - x = self.conv(x) - x = self.bn1(x) - x = self.act1(x) - return x - - -class EdgeResidual(nn.Module): - """ Residual block with expansion convolution followed by pointwise-linear w/ stride""" - - def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, - se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid, - norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0.): - super(EdgeResidual, self).__init__() - mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. 
- self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.drop_connect_rate = drop_connect_rate - - # Expansion convolution - self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) - self.bn1 = norm_layer(mid_chs, **norm_kwargs) - self.act1 = act_layer(inplace=True) - - # Squeeze-and-excitation - if self.has_se: - se_base_chs = mid_chs if se_reduce_mid else in_chs - self.se = SqueezeExcite( - mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn) - - # Point-wise linear projection - self.conv_pwl = select_conv2d( - mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type) - self.bn2 = norm_layer(out_chs, **norm_kwargs) - - def feature_module(self, location): - if location == 'post_exp': - return 'act1' - return 'conv_pwl' - - def feature_channels(self, location): - if location == 'post_exp': - return self.conv_exp.out_channels - # location == 'pre_pw' - return self.conv_pwl.in_channels - - def forward(self, x): - residual = x - - # Expansion convolution - x = self.conv_exp(x) - x = self.bn1(x) - x = self.act1(x) - - # Squeeze-and-excitation - if self.has_se: - x = self.se(x) - - # Point-wise linear projection - x = self.conv_pwl(x) - x = self.bn2(x) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_connect(x, self.training, self.drop_connect_rate) - x += residual - - return x - - -class DepthwiseSeparableConv(nn.Module): - """ DepthwiseSeparable block - Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion - factor of 1.0. This is an alternative to having a IR with an optional first pw conv. - """ - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, - pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid, - norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, drop_connect_rate=0.): - super(DepthwiseSeparableConv, self).__init__() - assert stride in [1, 2] - self.has_se = se_ratio is not None and se_ratio > 0. 
- self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip - self.has_pw_act = pw_act # activation after point-wise conv - self.drop_connect_rate = drop_connect_rate - - self.conv_dw = select_conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) - self.bn1 = norm_layer(in_chs, **norm_kwargs) - self.act1 = act_layer(inplace=True) - - # Squeeze-and-excitation - if self.has_se: - self.se = SqueezeExcite( - in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn) - - self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = norm_layer(out_chs, **norm_kwargs) - self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() - - def feature_module(self, location): - # no expansion in this block, pre pw only feature extraction point - return 'conv_pw' - - def feature_channels(self, location): - return self.conv_pw.in_channels - - def forward(self, x): - residual = x - - x = self.conv_dw(x) - x = self.bn1(x) - x = self.act1(x) - - if self.has_se: - x = self.se(x) - - x = self.conv_pw(x) - x = self.bn2(x) - x = self.act2(x) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_connect(x, self.training, self.drop_connect_rate) - x += residual - return x - - -class InvertedResidual(nn.Module): - """ Inverted residual block w/ optional SE and CondConv routing""" - - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, - exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, - se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid, - norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, - num_experts=0, drop_connect_rate=0.): - super(InvertedResidual, self).__init__() - mid_chs = int(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. 
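        # Note: mid_chs above is the expanded width (in_chs * exp_ratio); has_se gates creation of
        # the SqueezeExcite module further down, and has_residual (set on the next line) together
        # with drop_connect_rate determines whether forward() applies the stochastic-depth skip.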
- self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.drop_connect_rate = drop_connect_rate - - self.num_experts = num_experts - extra_args = dict() - if num_experts > 0: - extra_args = dict(num_experts=self.num_experts) - self.routing_fn = nn.Linear(in_chs, self.num_experts) - self.routing_act = torch.sigmoid - - # Point-wise expansion - self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **extra_args) - self.bn1 = norm_layer(mid_chs, **norm_kwargs) - self.act1 = act_layer(inplace=True) - - # Depth-wise convolution - self.conv_dw = select_conv2d( - mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, - padding=pad_type, depthwise=True, **extra_args) - self.bn2 = norm_layer(mid_chs, **norm_kwargs) - self.act2 = act_layer(inplace=True) - - # Squeeze-and-excitation - if self.has_se: - se_base_chs = mid_chs if se_reduce_mid else in_chs - self.se = SqueezeExcite( - mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_layer=act_layer, gate_fn=se_gate_fn) - - # Point-wise linear projection - self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **extra_args) - self.bn3 = norm_layer(out_chs, **norm_kwargs) - - def feature_module(self, location): - if location == 'post_exp': - return 'act1' - return 'conv_pwl' - - def feature_channels(self, location): - if location == 'post_exp': - return self.conv_pw.out_channels - # location == 'pre_pw' - return self.conv_pwl.in_channels - - def forward(self, x): - residual = x - - conv_pw, conv_dw, conv_pwl = self.conv_pw, self.conv_dw, self.conv_pwl - if self.num_experts > 0: - pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) - routing_weights = self.routing_act(self.routing_fn(pooled_inputs)) - conv_pw = partial(self.conv_pw, routing_weights=routing_weights) - conv_dw = partial(self.conv_dw, routing_weights=routing_weights) - conv_pwl = partial(self.conv_pwl, routing_weights=routing_weights) - - # Point-wise expansion - x = conv_pw(x) - x = self.bn1(x) - x = self.act1(x) - - # Depth-wise convolution - x = conv_dw(x) - x = self.bn2(x) - x = self.act2(x) - - # Squeeze-and-excitation - if self.has_se: - x = self.se(x) - - # Point-wise linear projection - x = conv_pwl(x) - x = self.bn3(x) - - if self.has_residual: - if self.drop_connect_rate > 0.: - x = drop_connect(x, self.training, self.drop_connect_rate) - x += residual - - return x +class EfficientNet(nn.Module): + """ (Generic) EfficientNet + A flexible and performant PyTorch implementation of efficient network architectures, including: + * EfficientNet B0-B8 + * EfficientNet-EdgeTPU + * EfficientNet-CondConv + * MixNet S, M, L, XL + * MnasNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 -class _GenEfficientNet(nn.Module): - """ Generic EfficientNet Base """ - def __init__(self, block_args, in_chans=3, stem_size=32, + def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., - se_gate_fn=sigmoid, se_reduce_mid=False, norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, - feature_location='pre_pwl'): - super(_GenEfficientNet, self).__init__() + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + global_pool='avg', weight_init='goog'): + super(EfficientNet, self).__init__() + norm_kwargs = norm_kwargs or {} + + self.num_classes 
= num_classes + self.num_features = num_features self.drop_rate = drop_rate self._in_chs = in_chans # Stem - stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size # Middle stages (IR/ER/DS Blocks) - builder = _BlockBuilder( - channel_multiplier, channel_divisor, channel_min, - output_stride, pad_type, act_layer, se_gate_fn, se_reduce_mid, - norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG) + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.feature_info = builder.features self._in_chs = builder.in_chs - def as_sequential(self): - layers = [self.conv_stem, self.bn1, self.act1] - layers.extend(self.blocks) - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv_stem(x) - x = self.bn1(x) - x = self.act1(x) - x = self.blocks(x) - return x - - -class GenEfficientNet(_GenEfficientNet): - """ Generic EfficientNet - - An implementation of efficient network architectures, in many cases mobile optimized networks: - * MobileNet-V1 - * MobileNet-V2 - * MobileNet-V3 - * MnasNet A1, B1, and small - * FBNet A, B, and C - * ChamNet (arch details are murky) - * Single-Path NAS Pixel1 - * EfficientNet B0-B7 - * MixNet S, M, L - """ - - def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, - channel_multiplier=1.0, channel_divisor=8, channel_min=None, - pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., - se_gate_fn=sigmoid, se_reduce_mid=False, - norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, - global_pool='avg', head_conv='default', weight_init='goog'): - - self.num_classes = num_classes - self.num_features = num_features - super(GenEfficientNet, self).__init__( # FIXME it would be nice if Python made this nicer - block_args, in_chans=in_chans, stem_size=stem_size, - pad_type=pad_type, act_layer=act_layer, drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, - channel_multiplier=channel_multiplier, channel_divisor=channel_divisor, channel_min=channel_min, - se_gate_fn=se_gate_fn, se_reduce_mid=se_reduce_mid, norm_layer=norm_layer, norm_kwargs=norm_kwargs) - # Head + Pooling - self.conv_head = None - self.global_pool = None - self.act2 = None - self.forward_head = None - self.head_conv = head_conv - if head_conv == 'efficient': - self._create_head_efficient(global_pool, pad_type, act_layer) - elif head_conv == 'default': - self._create_head_default(global_pool, pad_type, act_layer, norm_layer, norm_kwargs) + self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) + self.bn2 = norm_layer(self.num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) # Classifier self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes) for m in self.modules(): if weight_init == 'goog': - _init_weight_goog(m) + efficientnet_init_goog(m) else: - _init_weight_default(m) - - def _create_head_default(self, global_pool, pad_type, act_layer, norm_layer, 
norm_kwargs): - self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) - self.bn2 = norm_layer(self.num_features, **norm_kwargs) - self.act2 = act_layer(inplace=True) - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - - def _create_head_efficient(self, global_pool, pad_type, act_layer): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) - self.act2 = act_layer(inplace=True) - - def _forward_head_default(self, x): - x = self.conv_head(x) - x = self.bn2(x) - x = self.act2(x) - return x - - def _forward_head_efficient(self, x): - x = self.global_pool(x) - x = self.conv_head(x) - x = self.act2(x) - return x + efficientnet_init_default(m) def as_sequential(self): layers = [self.conv_stem, self.bn1, self.act1] layers.extend(self.blocks) - if self.head_conv == 'efficient': - layers.extend([self.global_pool, self.conv_head, self.act2]) - else: - layers.extend([self.conv_head, self.bn2, self.act2]) - if self.global_pool is not None: - layers.append(self.global_pool) - layers.extend([Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) def get_classifier(self): @@ -1075,86 +266,80 @@ class GenEfficientNet(_GenEfficientNet): self.classifier = None def forward_features(self, x): - x = super(GenEfficientNet, self).forward(x) - if self.head_conv == 'efficient': - x = self._forward_head_efficient(x) - elif self.head_conv == 'default': - x = self._forward_head_default(x) + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) return x def forward(self, x): x = self.forward_features(x) - if self.global_pool is not None and x.shape[-1] > 1 or x.shape[-2] > 1: - x = self.global_pool(x) + x = self.global_pool(x) x = x.flatten(1) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x) -class GenEfficientNetFeatures(_GenEfficientNet): - """ Generic EfficientNet Feature Extractor +class EfficientNetFeatures(nn.Module): + """ EfficientNet Feature Extractor + + A work-in-progress feature extraction module for EfficientNet, to use as a backbone for segmentation + and object detection models. 
""" def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., - se_gate_fn=sigmoid, se_reduce_mid=False, norm_layer=nn.BatchNorm2d, norm_kwargs=_BN_ARGS_PT, - weight_init='goog'): + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(EfficientNetFeatures, self).__init__() + norm_kwargs = norm_kwargs or {} + + # TODO only create stages needed, currently all stages are created regardless of out_indices + num_stages = max(out_indices) + 1 - # validate and modify block arguments and out indices for feature extraction - num_stages = max(out_indices) + 1 # FIXME reduce num stages created if not needed - #assert len(block_args) >= num_stages - 1 - #block_args = block_args[:num_stages - 1] self.out_indices = out_indices + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size - # FIXME it would be nice if Python made this nicer without using kwargs and erasing IDE hints, etc - super(GenEfficientNetFeatures, self).__init__( - block_args, in_chans=in_chans, stem_size=stem_size, - output_stride=output_stride, pad_type=pad_type, act_layer=act_layer, - drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location, - channel_multiplier=channel_multiplier, channel_divisor=channel_divisor, channel_min=channel_min, - se_gate_fn=se_gate_fn, se_reduce_mid=se_reduce_mid, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features # builder provides info about feature channels for each block + self._in_chs = builder.in_chs for m in self.modules(): if weight_init == 'goog': - _init_weight_goog(m) + efficientnet_init_goog(m) else: - _init_weight_default(m) + efficientnet_init_default(m) if _DEBUG: for k, v in self.feature_info.items(): print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) + + # Register feature extraction hooks with FeatureHooks helper hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] - self._feature_outputs = None - self._register_hooks(hooks) - - def _collect_output_hook(self, name, *args): - x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre - if isinstance(x, tuple): - x = x[0] # unwrap input tuple - self._feature_outputs[x.device][name] = x - - def _get_output(self, device): - output = tuple(self._feature_outputs[device].values())[::-1] - self._feature_outputs[device] = OrderedDict() - return output - - def _register_hooks(self, hooks): - # setup feature hooks - modules = {k: v for k, v in self.named_modules()} - for h in hooks: - hook_name = h['name'] - m = modules[hook_name] - hook_fn = 
partial(self._collect_output_hook, hook_name) - if h['type'] == 'forward_pre': - m.register_forward_pre_hook(hook_fn) - else: - m.register_forward_hook(hook_fn) - self._feature_outputs = defaultdict(OrderedDict) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) def feature_channels(self, idx=None): + """ Feature Channel Shortcut + Returns feature channel count for each output index if idx == None. If idx is an integer, will + return feature channel count for that feature block index (independent of out_indices setting). + """ if isinstance(idx, int): return self.feature_info[idx]['num_chs'] return [self.feature_info[i]['num_chs'] for i in self.out_indices] @@ -1164,7 +349,7 @@ class GenEfficientNetFeatures(_GenEfficientNet): x = self.bn1(x) x = self.act1(x) self.blocks(x) - return self._get_output(x.device) + return self.feature_hooks.get_output(x.device) def _create_model(model_kwargs, default_cfg, pretrained=False): @@ -1173,10 +358,10 @@ def _create_model(model_kwargs, default_cfg, pretrained=False): model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) - model_class = GenEfficientNetFeatures + model_class = EfficientNetFeatures else: load_strict = True - model_class = GenEfficientNet + model_class = EfficientNet model = model_class(**model_kwargs) model.default_cfg = default_cfg @@ -1216,10 +401,10 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) ['ir_r1_k3_s1_e6_c320'], ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def), + block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1252,10 +437,10 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) ['ir_r1_k3_s1_e6_c320_noskip'] ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def), + block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1281,36 +466,10 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar ['ir_r1_k3_s1_e6_c144'] ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def), + block_args=decode_arch_def(arch_def), stem_size=8, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), - **kwargs - ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) - return model - - -def _gen_mobilenet_v1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): - """ Generate MobileNet-V1 network - Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py - Paper: https://arxiv.org/abs/1801.04381 - """ - arch_def = [ - ['dsa_r1_k3_s1_c64'], - ['dsa_r2_k3_s2_c128'], - ['dsa_r2_k3_s2_c256'], - ['dsa_r6_k3_s2_c512'], - ['dsa_r2_k3_s2_c1024'], - ] - model_kwargs = dict( - block_args=_decode_arch_def(arch_def), - stem_size=32, - num_features=1024, - channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), - act_layer=nn.ReLU6, - head_conv='none', + norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1332,10 +491,10 @@ def _gen_mobilenet_v2(variant, channel_multiplier=1.0, 
pretrained=False, **kwarg ['ir_r1_k3_s1_e6_c320'], ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def), + block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), act_layer=nn.ReLU6, **kwargs ) @@ -1343,104 +502,6 @@ def _gen_mobilenet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwarg return model -def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): - """Creates a MobileNet-V3 model. - - Ref impl: ? - Paper: https://arxiv.org/abs/1905.02244 - - Args: - channel_multiplier: multiplier to number of channels per layer. - """ - arch_def = [ - # stage 0, 112x112 in - ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu - # stage 1, 112x112 in - ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu - # stage 2, 56x56 in - ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu - # stage 3, 28x28 in - ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish - # stage 4, 14x14in - ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish - # stage 5, 14x14in - ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish - # stage 6, 7x7 in - ['cn_r1_k1_s1_c960'], # hard-swish - ] - model_kwargs = dict( - block_args=_decode_arch_def(arch_def), - stem_size=16, - channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), - act_layer=HardSwish, - se_gate_fn=hard_sigmoid, - se_reduce_mid=True, - head_conv='efficient', - **kwargs, - ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) - return model - - -def _gen_chamnet_v1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): - """ Generate Chameleon Network (ChamNet) - - Paper: https://arxiv.org/abs/1812.08934 - Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py - - FIXME: this a bit of an educated guess based on trunkd def in maskrcnn_benchmark - """ - arch_def = [ - ['ir_r1_k3_s1_e1_c24'], - ['ir_r2_k7_s2_e4_c48'], - ['ir_r5_k3_s2_e7_c64'], - ['ir_r7_k5_s2_e12_c56'], - ['ir_r5_k3_s1_e8_c88'], - ['ir_r4_k3_s2_e7_c152'], - ['ir_r1_k3_s1_e10_c104'], - ] - model_kwargs = dict( - block_args=_decode_arch_def(arch_def), - stem_size=32, - num_features=1280, # no idea what this is? try mobile/mnasnet default? - channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), - **kwargs - ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) - return model - - -def _gen_chamnet_v2(variant, channel_multiplier=1.0, pretrained=False, **kwargs): - """ Generate Chameleon Network (ChamNet) - - Paper: https://arxiv.org/abs/1812.08934 - Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py - - FIXME: this a bit of an educated guess based on trunk def in maskrcnn_benchmark - """ - arch_def = [ - ['ir_r1_k3_s1_e1_c24'], - ['ir_r4_k5_s2_e8_c32'], - ['ir_r6_k7_s2_e5_c48'], - ['ir_r3_k5_s2_e9_c56'], - ['ir_r6_k3_s1_e6_c56'], - ['ir_r6_k3_s2_e2_c152'], - ['ir_r1_k3_s1_e6_c112'], - ] - model_kwargs = dict( - block_args=_decode_arch_def(arch_def), - stem_size=32, - num_features=1280, # no idea what this is? try mobile/mnasnet default? 
- channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), - **kwargs - ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) - return model - - def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """ FBNet-C @@ -1460,11 +521,11 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): ['ir_r1_k3_s1_e6_c352'], ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def), + block_args=decode_arch_def(arch_def), stem_size=16, num_features=1984, # paper suggests this, but is not 100% clear channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1496,10 +557,10 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): ['ir_r1_k3_s1_e6_c320_noskip'] ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def), + block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1522,6 +583,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), Args: channel_multiplier: multiplier to number of channels per layer @@ -1538,12 +600,12 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre ['ir_r1_k3_s1_e6_c320_se0.25'], ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def, depth_multiplier), - num_features=_round_channels(1280, channel_multiplier, 8, None), + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), act_layer=Swish, + norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1567,11 +629,11 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 ['ir_r2_k5_s2_e8_c192'], ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def, depth_multiplier), - num_features=_round_channels(1280, channel_multiplier, 8, None), + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), act_layer=nn.ReLU, **kwargs, ) @@ -1597,11 +659,11 @@ def _gen_efficientnet_condconv( # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers model_kwargs = dict( - block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), - num_features=_round_channels(1280, channel_multiplier, 8, None), + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), act_layer=Swish, **kwargs, ) @@ 
-1631,11 +693,11 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): # 7x7 ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def), + block_args=decode_arch_def(arch_def), num_features=1536, stem_size=16, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1664,11 +726,11 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai # 7x7 ] model_kwargs = dict( - block_args=_decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), num_features=1536, stem_size=24, channel_multiplier=channel_multiplier, - norm_kwargs=_resolve_bn_args(kwargs), + norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, default_cfgs[variant], pretrained) @@ -1750,13 +812,6 @@ def mnasnet_small(pretrained=False, **kwargs): return model -@register_model -def mobilenetv1_100(pretrained=False, **kwargs): - """ MobileNet V1 """ - model = _gen_mobilenet_v1('mobilenetv1_100', 1.0, pretrained=pretrained, **kwargs) - return model - - @register_model def mobilenetv2_100(pretrained=False, **kwargs): """ MobileNet V2 """ @@ -1764,54 +819,16 @@ def mobilenetv2_100(pretrained=False, **kwargs): return model -@register_model -def mobilenetv3_050(pretrained=False, **kwargs): - """ MobileNet V3 """ - model = _gen_mobilenet_v3('mobilenetv3_050', 0.5, pretrained=pretrained, **kwargs) - return model - - -@register_model -def mobilenetv3_075(pretrained=False, **kwargs): - """ MobileNet V3 """ - model = _gen_mobilenet_v3('mobilenetv3_075', 0.75, pretrained=pretrained, **kwargs) - return model - - -@register_model -def mobilenetv3_100(pretrained=False, **kwargs): - """ MobileNet V3 """ - if pretrained: - # pretrained model trained with non-default BN epsilon - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT - model = _gen_mobilenet_v3('mobilenetv3_100', 1.0, pretrained=pretrained, **kwargs) - return model - - @register_model def fbnetc_100(pretrained=False, **kwargs): """ FBNet-C """ if pretrained: # pretrained model trained with non-default BN epsilon - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) return model -@register_model -def chamnetv1_100(pretrained=False, **kwargs): - """ ChamNet """ - model = _gen_chamnet_v1('chamnetv1_100', 1.0, pretrained=pretrained, **kwargs) - return model - - -@register_model -def chamnetv2_100(pretrained=False, **kwargs): - """ ChamNet """ - model = _gen_chamnet_v2('chamnetv2_100', 1.0, pretrained=pretrained, **kwargs) - return model - - @register_model def spnasnet_100(pretrained=False, **kwargs): """ Single-Path NAS Pixel1""" @@ -1957,7 +974,7 @@ def efficientnet_cc_b1_8e(pretrained=False, **kwargs): @register_model def tf_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) @@ -1967,7 +984,7 @@ def tf_efficientnet_b0(pretrained=False, **kwargs): @register_model def tf_efficientnet_b1(pretrained=False, **kwargs): """ EfficientNet-B1. 
Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) @@ -1977,7 +994,7 @@ def tf_efficientnet_b1(pretrained=False, **kwargs): @register_model def tf_efficientnet_b2(pretrained=False, **kwargs): """ EfficientNet-B2. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) @@ -1987,7 +1004,7 @@ def tf_efficientnet_b2(pretrained=False, **kwargs): @register_model def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B3. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) @@ -1997,7 +1014,7 @@ def tf_efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs) @register_model def tf_efficientnet_b4(pretrained=False, **kwargs): """ EfficientNet-B4. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) @@ -2007,7 +1024,7 @@ def tf_efficientnet_b4(pretrained=False, **kwargs): @register_model def tf_efficientnet_b5(pretrained=False, **kwargs): """ EfficientNet-B5. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) @@ -2018,7 +1035,7 @@ def tf_efficientnet_b5(pretrained=False, **kwargs): def tf_efficientnet_b6(pretrained=False, **kwargs): """ EfficientNet-B6. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) @@ -2029,17 +1046,111 @@ def tf_efficientnet_b6(pretrained=False, **kwargs): def tf_efficientnet_b7(pretrained=False, **kwargs): """ EfficientNet-B7. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) return model +@register_model +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1. 
Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ap(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ EfficientNet-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + + @register_model def tf_efficientnet_es(pretrained=False, **kwargs): """ EfficientNet-Edge Small. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) @@ -2049,7 +1160,7 @@ def tf_efficientnet_es(pretrained=False, **kwargs): @register_model def tf_efficientnet_em(pretrained=False, **kwargs): """ EfficientNet-Edge-Medium. 
Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) @@ -2059,7 +1170,7 @@ def tf_efficientnet_em(pretrained=False, **kwargs): @register_model def tf_efficientnet_el(pretrained=False, **kwargs): """ EfficientNet-Edge-Large. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_edge( 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) @@ -2071,7 +1182,7 @@ def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_condconv( 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) @@ -2083,7 +1194,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_condconv( 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, @@ -2095,7 +1206,7 @@ def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet_condconv( 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, @@ -2155,7 +1266,7 @@ def mixnet_xxl(pretrained=False, **kwargs): def tf_mixnet_s(pretrained=False, **kwargs): """Creates a MixNet Small model. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_s( 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) @@ -2166,7 +1277,7 @@ def tf_mixnet_s(pretrained=False, **kwargs): def tf_mixnet_m(pretrained=False, **kwargs): """Creates a MixNet Medium model. Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_m( 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) @@ -2177,7 +1288,7 @@ def tf_mixnet_m(pretrained=False, **kwargs): def tf_mixnet_l(pretrained=False, **kwargs): """Creates a MixNet Large model. 
Tensorflow compatible variant """ - kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mixnet_m( 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) diff --git a/timm/models/layers.py b/timm/models/layers.py deleted file mode 100644 index c8e0a837..00000000 --- a/timm/models/layers.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def versiontuple(v): - return tuple(map(int, (v.split("."))))[:3] - - -if versiontuple(torch.__version__) >= versiontuple('1.2.0'): - Flatten = nn.Flatten -else: - class Flatten(nn.Module): - r""" - Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`. - Args: - start_dim: first dim to flatten (default = 1). - end_dim: last dim to flatten (default = -1). - Shape: - - Input: :math:`(N, *dims)` - - Output: :math:`(N, \prod *dims)` (for the default case). - """ - __constants__ = ['start_dim', 'end_dim'] - - def __init__(self, start_dim=1, end_dim=-1): - super(Flatten, self).__init__() - self.start_dim = start_dim - self.end_dim = end_dim - - def forward(self, input): - return input.flatten(self.start_dim, self.end_dim) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py new file mode 100644 index 00000000..13fd16e6 --- /dev/null +++ b/timm/models/mobilenetv3.py @@ -0,0 +1,439 @@ + +""" MobileNet V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .efficientnet_builder import * +from .activations import HardSwish, hard_sigmoid +from .registry import register_model +from .helpers import load_pretrained +from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .conv2d_layers import select_conv2d +from .feature_hooks import FeatureHooks +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD + +__all__ = ['MobileNetV3'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mobilenetv3_large_075': _cfg(url=''), + 'mobilenetv3_large_100': _cfg(url=''), + 'mobilenetv3_rw': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + interpolation='bicubic'), + 'tf_mobilenetv3_large_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_075': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_100': _cfg( + url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), +} + +_DEBUG = False + + +class MobileNetV3(nn.Module): + """ MobiletNet-V3 + + Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific + 'efficient head', where global pooling is done before the head convolution without a final batch-norm + layer before the classifier. + + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + global_pool='avg', weight_init='goog'): + super(MobileNetV3, self).__init__() + + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features + self._in_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + + # Classifier + self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes) + + for m in self.modules(): + if weight_init == 'goog': + efficientnet_init_goog(m) + else: + efficientnet_init_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + del self.classifier + if num_classes: + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), num_classes) + else: + self.classifier = None + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = 
F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class MobileNetV3Features(nn.Module): + """ MobileNetV3 Feature Extractor + + A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation + and object detection models. + """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', + in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', + act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(MobileNetV3Features, self).__init__() + norm_kwargs = norm_kwargs or {} + + # TODO only create stages needed, currently all stages are created regardless of out_indices + num_stages = max(out_indices) + 1 + + self.out_indices = out_indices + self.drop_rate = drop_rate + self._in_chs = in_chans + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + self.feature_info = builder.features # builder provides info about feature channels for each block + self._in_chs = builder.in_chs + + for m in self.modules(): + if weight_init == 'goog': + efficientnet_init_goog(m) + else: + efficientnet_init_default(m) + + if _DEBUG: + for k, v in self.feature_info.items(): + print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) + + # Register feature extraction hooks with FeatureHooks helper + hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' + hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def feature_channels(self, idx=None): + """ Feature Channel Shortcut + Returns feature channel count for each output index if idx == None. If idx is an integer, will + return feature channel count for that feature block index (independent of out_indices setting). + """ + if isinstance(idx, int): + return self.feature_info[idx]['num_chs'] + return [self.feature_info[i]['num_chs'] for i in self.out_indices] + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + self.blocks(x) + return self.feature_hooks.get_output(x.device) + + +def _create_model(model_kwargs, default_cfg, pretrained=False): + if model_kwargs.pop('features_only', False): + load_strict = False + model_kwargs.pop('num_classes', 0) + model_kwargs.pop('num_features', 0) + model_kwargs.pop('head_conv', None) + model_class = MobileNetV3Features + else: + load_strict = True + model_class = MobileNetV3 + + model = model_class(**model_kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained( + model, + default_cfg, + num_classes=model_kwargs.get('num_classes', 0), + in_chans=model_kwargs.get('in_chans', 3), + strict=load_strict) + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? 
+    Paper: https://arxiv.org/abs/1905.02244
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer.
+    """
+    arch_def = [
+        # stage 0, 112x112 in
+        ['ds_r1_k3_s1_e1_c16_nre_noskip'],  # relu
+        # stage 1, 112x112 in
+        ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
+        # stage 2, 56x56 in
+        ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
+        # stage 3, 28x28 in
+        ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
+        # stage 4, 14x14 in
+        ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
+        # stage 5, 14x14 in
+        ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
+        # stage 6, 7x7 in
+        ['cn_r1_k1_s1_c960'],  # hard-swish
+    ]
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def),
+        head_bias=False,
+        channel_multiplier=channel_multiplier,
+        norm_kwargs=resolve_bn_args(kwargs),
+        act_layer=HardSwish,
+        se_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1),
+        **kwargs,
+    )
+    model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+    return model
+
+
+def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates a MobileNet-V3 model.
+
+    Ref impl: ?
+    Paper: https://arxiv.org/abs/1905.02244
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer.
+    """
+    if 'small' in variant:
+        num_features = 1024
+        if 'minimal' in variant:
+            act_layer = nn.ReLU
+            arch_def = [
+                # stage 0, 112x112 in
+                ['ds_r1_k3_s2_e1_c16'],
+                # stage 1, 56x56 in
+                ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
+                # stage 2, 28x28 in
+                ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
+                # stage 3, 14x14 in
+                ['ir_r2_k3_s1_e3_c48'],
+                # stage 4, 14x14 in
+                ['ir_r3_k3_s2_e6_c96'],
+                # stage 6, 7x7 in
+                ['cn_r1_k1_s1_c576'],
+            ]
+        else:
+            act_layer = HardSwish
+            arch_def = [
+                # stage 0, 112x112 in
+                ['ds_r1_k3_s2_e1_c16_se0.25_nre'],  # relu
+                # stage 1, 56x56 in
+                ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'],  # relu
+                # stage 2, 28x28 in
+                ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'],  # hard-swish
+                # stage 3, 14x14 in
+                ['ir_r2_k5_s1_e3_c48_se0.25'],  # hard-swish
+                # stage 4, 14x14 in
+                ['ir_r3_k5_s2_e6_c96_se0.25'],  # hard-swish
+                # stage 6, 7x7 in
+                ['cn_r1_k1_s1_c576'],  # hard-swish
+            ]
+    else:
+        num_features = 1280
+        if 'minimal' in variant:
+            act_layer = nn.ReLU
+            arch_def = [
+                # stage 0, 112x112 in
+                ['ds_r1_k3_s1_e1_c16'],
+                # stage 1, 112x112 in
+                ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
+                # stage 2, 56x56 in
+                ['ir_r3_k3_s2_e3_c40'],
+                # stage 3, 28x28 in
+                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
+                # stage 4, 14x14 in
+                ['ir_r2_k3_s1_e6_c112'],
+                # stage 5, 14x14 in
+                ['ir_r3_k3_s2_e6_c160'],
+                # stage 6, 7x7 in
+                ['cn_r1_k1_s1_c960'],
+            ]
+        else:
+            act_layer = HardSwish
+            arch_def = [
+                # stage 0, 112x112 in
+                ['ds_r1_k3_s1_e1_c16_nre'],  # relu
+                # stage 1, 112x112 in
+                ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
+                # stage 2, 56x56 in
+                ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
+                # stage 3, 28x28 in
+                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
+                # stage 4, 14x14 in
+                ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
+                # stage 5, 14x14 in
+                ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
+                # stage 6, 7x7 in
+                ['cn_r1_k1_s1_c960'],  # hard-swish
+            ]
+
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def),
+        num_features=num_features,
+        stem_size=16,
+        channel_multiplier=channel_multiplier,
+        norm_kwargs=resolve_bn_args(kwargs),
+        act_layer=act_layer,
+        se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
+        **kwargs,
+    )
+    model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+    return model
+
+
+@register_model
+def mobilenetv3_rw(pretrained=False, **kwargs):
+    """ MobileNet V3 """
+    if pretrained:
+        # pretrained model trained with non-default BN epsilon
+        kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
+    """ MobileNet V3 """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
+    """ MobileNet V3 """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
+    """ MobileNet V3 """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
+    """ MobileNet V3 """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
+    """ MobileNet V3 """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
+    """ MobileNet V3 """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+    return model
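
A minimal usage sketch for the classification path added above: the entrypoints register with timm's model factory, so the MobileNetV3 'efficient head' model can be built by name. This assumes this branch is installed as `timm`; the 224x224 input and 1000-class head are just the defaults from `_cfg` and `num_classes`.

import torch
import timm

# Build the TF-ported large variant via the registry (pretrained weights optional).
model = timm.create_model('tf_mobilenetv3_large_100', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)                   # (1, 1000) with the default classifier
    pooled = model.forward_features(x)  # (1, 1280, 1, 1): pooled output of the efficient head

# The classifier can be swapped for fine-tuning without touching the trunk.
model.reset_classifier(num_classes=10)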
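
A similar sketch for the work-in-progress backbone path: passing `features_only=True` through an entrypoint makes `_create_model` build `MobileNetV3Features`, whose forward returns the hooked intermediate feature maps instead of logits (this assumes `FeatureHooks.get_output` hands back one tensor per entry in `out_indices`, as the forward above suggests).

import torch
from timm.models.mobilenetv3 import tf_mobilenetv3_large_100

# features_only is popped by _create_model and selects MobileNetV3Features.
backbone = tf_mobilenetv3_large_100(features_only=True, out_indices=(0, 1, 2, 3, 4))
backbone.eval()

print(backbone.feature_channels())  # channel count for each requested feature index

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = backbone(x)  # intermediate maps collected by the FeatureHooks helper

for f in features:
    print(f.shape)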