diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index cd52b885..ea72d07c 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -8,6 +8,7 @@ import numpy as np import math +# Tuple helpers ripped from PyTorch def _ntuple(n): def parse(x): if isinstance(x, container_abcs.Iterable): @@ -77,7 +78,7 @@ def get_padding_value(padding, kernel_size, **kwargs): # static case, no extra overhead padding = _get_padding(kernel_size, **kwargs) else: - # dynamic padding + # dynamic 'SAME' padding, has runtime/GPU memory overhead padding = 0 dynamic = True elif padding == 'valid': @@ -101,6 +102,7 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): class MixedConv2d(nn.Module): """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py """ @@ -152,7 +154,11 @@ def get_condconv_initializer(initializer, num_experts, expert_shape): class CondConv2d(nn.Module): """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 """ def __init__(self, in_channels, out_channels, kernel_size=3, @@ -211,6 +217,7 @@ class CondConv2d(nn.Module): if self._use_groups: new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size weight = weight.view(new_weight_shape) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel x = x.view(1, B * C, H, W) out = self.conv_fn( x, weight, bias, stride=self.stride, padding=self.padding, diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index e51bab2a..216ea6ee 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -2,6 +2,8 @@ A generic class with building blocks to support a variety of models with efficient architectures: * EfficientNet (B0-B7) +* EfficientNet-EdgeTPU +* EfficientNet-CondConv * MixNet (Small, Medium, and Large) * MnasNet B1, A1 (SE), Small * MobileNet V1, V2, and V3 @@ -31,6 +33,7 @@ from .registry import register_model, model_entrypoint from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from .conv2d_layers import select_conv2d +from .layers import Flatten from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD @@ -1050,16 +1053,14 @@ class GenEfficientNet(_GenEfficientNet): layers = [self.conv_stem, self.bn1, self.act1] layers.extend(self.blocks) if self.head_conv == 'efficient': - layers.extend([self.global_pool, self.bn2, self.act2]) + layers.extend([self.global_pool, self.conv_head, self.act2]) else: layers.extend([self.conv_head, self.bn2, self.act2]) if self.global_pool is not None: layers.append(self.global_pool) - #append flatten layer - layers.append(self.classifier) + layers.extend([Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) - def get_classifier(self): return self.classifier @@ -1106,7 +1107,8 @@ class GenEfficientNetFeatures(_GenEfficientNet): #assert len(block_args) >= num_stages - 1 #block_args = block_args[:num_stages - 1] - super(GenEfficientNetFeatures, self).__init__( # FIXME it would be nice if Python made this nicer + # FIXME it would be nice if Python made this nicer without using kwargs and erasing IDE hints, etc + super(GenEfficientNetFeatures, self).__init__( block_args, in_chans=in_chans, stem_size=stem_size, output_stride=output_stride, pad_type=pad_type, act_layer=act_layer, drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location, @@ -1548,6 +1550,11 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-EdgeTPU model + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu + """ + arch_def = [ # NOTE `fc` is present to override a mismatch between stem channels and in chs not # present in other models @@ -1573,8 +1580,10 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 def _gen_efficientnet_condconv( variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an EfficientNet-CondConv model. - """Creates an efficientnet-condconv model.""" + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + """ arch_def = [ ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'], @@ -1584,6 +1593,8 @@ def _gen_efficientnet_condconv( ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], ] + # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and + # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers model_kwargs = dict( block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), num_features=_round_channels(1280, channel_multiplier, 8, None), @@ -2056,7 +2067,7 @@ def tf_efficientnet_el(pretrained=False, **kwargs): @register_model def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): - """ EfficientNet-B0 """ + """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT @@ -2068,7 +2079,7 @@ def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): @register_model def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): - """ EfficientNet-B0 """ + """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT @@ -2080,7 +2091,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): @register_model def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): - """ EfficientNet-B0 """ + """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.2 #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT diff --git a/timm/models/layers.py b/timm/models/layers.py new file mode 100644 index 00000000..c8e0a837 --- /dev/null +++ b/timm/models/layers.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def versiontuple(v): + return tuple(map(int, (v.split("."))))[:3] + + +if versiontuple(torch.__version__) >= versiontuple('1.2.0'): + Flatten = nn.Flatten +else: + class Flatten(nn.Module): + r""" + Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`. + Args: + start_dim: first dim to flatten (default = 1). + end_dim: last dim to flatten (default = -1). + Shape: + - Input: :math:`(N, *dims)` + - Output: :math:`(N, \prod *dims)` (for the default case). + """ + __constants__ = ['start_dim', 'end_dim'] + + def __init__(self, start_dim=1, end_dim=-1): + super(Flatten, self).__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def forward(self, input): + return input.flatten(self.start_dim, self.end_dim)