diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 48678e67..4ef152d5 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -326,7 +326,6 @@ class EfficientNet(nn.Module): # Stem if not fix_stem: stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - print(stem_size) self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -393,7 +392,7 @@ class EfficientNetFeatures(nn.Module): and object detection models. """ - def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl', + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): @@ -404,6 +403,7 @@ class EfficientNetFeatures(nn.Module): num_stages = max(out_indices) + 1 self.out_indices = out_indices + self.feature_location = feature_location self.drop_rate = drop_rate self._in_chs = in_chans @@ -420,18 +420,23 @@ class EfficientNetFeatures(nn.Module): channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) - self.feature_info = builder.features # builder provides info about feature channels for each block + self._feature_info = builder.features # builder provides info about feature channels for each block + self._stage_to_feature_idx = { + v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices} self._in_chs = builder.in_chs efficientnet_init_weights(self) if _DEBUG: - for k, v in self.feature_info.items(): + for k, v in self._feature_info.items(): print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs'])) # Register feature extraction hooks with FeatureHooks helper - hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward' - hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices] - self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = [dict( + name=self._feature_info[idx]['module'], + type=self._feature_info[idx]['hook_type']) for idx in out_indices] + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) def feature_channels(self, idx=None): """ Feature Channel Shortcut @@ -439,15 +444,32 @@ class EfficientNetFeatures(nn.Module): return feature channel count for that feature block index (independent of out_indices setting). """ if isinstance(idx, int): - return self.feature_info[idx]['num_chs'] - return [self.feature_info[i]['num_chs'] for i in self.out_indices] + return self._feature_info[idx]['num_chs'] + return [self._feature_info[i]['num_chs'] for i in self.out_indices] + + def feature_info(self, idx=None): + """ Feature Channel Shortcut + Returns feature channel count for each output index if idx == None. If idx is an integer, will + return feature channel count for that feature block index (independent of out_indices setting). + """ + if isinstance(idx, int): + return self._feature_info[idx] + return [self._feature_info[i] for i in self.out_indices] def forward(self, x): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) - self.blocks(x) - return self.feature_hooks.get_output(x.device) + if self.feature_hooks is None: + features = [] + for i, b in enumerate(self.blocks): + x = b(x) + if i in self._stage_to_feature_idx: + features.append(x) + return features + else: + self.blocks(x) + return self.feature_hooks.get_output(x.device) def _create_model(model_kwargs, default_cfg, pretrained=False): diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index b5de664d..cc4cdef1 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -120,11 +120,13 @@ class ConvBnAct(nn.Module): self.bn1 = norm_layer(out_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) - def feature_module(self, location): - return 'act1' - - def feature_channels(self, location): - return self.conv.out_channels + def feature_info(self, location): + if location == 'expansion' or location == 'depthwise': + # no expansion or depthwise this block, use act after conv + info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels) + else: # location == 'bottleneck' + info = dict(module='', hook_type='', num_chs=self.conv.out_channels) + return info def forward(self, x): x = self.conv(x) @@ -165,12 +167,15 @@ class DepthwiseSeparableConv(nn.Module): self.bn2 = norm_layer(out_chs, **norm_kwargs) self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() - def feature_module(self, location): - # no expansion in this block, pre pw only feature extraction point - return 'conv_pw' - - def feature_channels(self, location): - return self.conv_pw.in_channels + def feature_info(self, location): + if location == 'expansion': + # no expansion in this block, use depthwise, before SE + info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels) + elif location == 'depthwise': # after SE + info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) + else: # location == 'bottleneck' + info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) + return info def forward(self, x): residual = x @@ -232,16 +237,14 @@ class InvertedResidual(nn.Module): self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) self.bn3 = norm_layer(out_chs, **norm_kwargs) - def feature_module(self, location): - if location == 'post_exp': - return 'act1' - return 'conv_pwl' - - def feature_channels(self, location): - if location == 'post_exp': - return self.conv_pw.out_channels - # location == 'pre_pw' - return self.conv_pwl.in_channels + def feature_info(self, location): + if location == 'expansion': + info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels) + elif location == 'depthwise': # after SE + info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck' + info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return info def forward(self, x): residual = x @@ -359,16 +362,15 @@ class EdgeResidual(nn.Module): mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) - def feature_module(self, location): - if location == 'post_exp': - return 'act1' - return 'conv_pwl' - - def feature_channels(self, location): - if location == 'post_exp': - return self.conv_exp.out_channels - # location == 'pre_pw' - return self.conv_pwl.in_channels + def feature_info(self, location): + if location == 'expansion': + info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels) + elif location == 'depthwise': + # there is no depthwise, take after SE, before PWL + info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck' + info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return info def forward(self, x): residual = x diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 3876a2d1..842098cf 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -218,7 +218,7 @@ class EfficientNetBuilder: self.norm_kwargs = norm_kwargs self.drop_path_rate = drop_path_rate self.feature_location = feature_location - assert feature_location in ('pre_pwl', 'post_exp', '') + assert feature_location in ('bottleneck', 'depthwise', 'expansion', '') self.verbose = verbose # state updated during build, consumed by model @@ -313,20 +313,21 @@ class EfficientNetBuilder: block_args['stride'] = 1 do_extract = False - if self.feature_location == 'pre_pwl': + if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise': if last_block: next_stage_idx = stage_idx + 1 if next_stage_idx >= len(model_block_args): do_extract = True else: do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 - elif self.feature_location == 'post_exp': - if block_args['stride'] > 1 or (last_stack and last_block) : + elif self.feature_location == 'expansion': + if block_args['stride'] > 1 or (last_stack and last_block): do_extract = True if do_extract: extract_features = self.feature_location next_dilation = current_dilation + next_output_stride = current_stride if block_args['stride'] > 1: next_output_stride = current_stride * block_args['stride'] if next_output_stride > self.output_stride: @@ -347,14 +348,13 @@ class EfficientNetBuilder: # stash feature module name and channel info for model feature extraction if extract_features: - feature_module = block.feature_module(extract_features) - if feature_module: - feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module - feature_channels = block.feature_channels(extract_features) - self.features[feature_idx] = dict( - name=feature_module, - num_chs=feature_channels - ) + feature_info = block.feature_info(extract_features) + if feature_info['module']: + feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module'] + feature_info['stage_idx'] = stage_idx + feature_info['block_idx'] = block_idx + feature_info['reduction'] = current_stride + self.features[feature_idx] = feature_info feature_idx += 1 total_block_idx += 1 # incr global block idx (across all stacks) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index f012c3cf..3dec7498 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,9 +1,10 @@ from .padding import get_padding -from .avg_pool2d_same import AvgPool2dSame +from .pool2d_same import AvgPool2dSame from .conv2d_same import Conv2dSame from .conv_bn_act import ConvBnAct from .mixed_conv2d import MixedConv2d from .cond_conv2d import CondConv2d, get_condconv_initializer +from .pool2d_same import create_pool2d from .create_conv2d import create_conv2d from .create_attn import create_attn from .selective_kernel import SelectiveKernelConv diff --git a/timm/models/layers/avg_pool2d_same.py b/timm/models/layers/avg_pool2d_same.py deleted file mode 100644 index 33656e79..00000000 --- a/timm/models/layers/avg_pool2d_same.py +++ /dev/null @@ -1,31 +0,0 @@ -""" AvgPool2d w/ Same Padding - -Hacked together by Ross Wightman -""" -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import List -import math - -from .helpers import tup_pair -from .padding import pad_same - - -def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), - ceil_mode: bool = False, count_include_pad: bool = True): - x = pad_same(x, kernel_size, stride) - return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) - - -class AvgPool2dSame(nn.AvgPool2d): - """ Tensorflow like 'SAME' wrapper for 2D average pooling - """ - def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): - kernel_size = tup_pair(kernel_size) - stride = tup_pair(stride) - super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) - - def forward(self, x): - return avg_pool2d_same( - x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index 7b038ee7..0241b501 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -14,7 +14,8 @@ from torch import nn as nn from torch.nn import functional as F from .helpers import tup_pair -from .conv2d_same import get_padding_value, conv2d_same +from .conv2d_same import conv2d_same +from timm.models.layers.padding import get_padding_value def get_condconv_initializer(initializer, num_experts, expert_shape): diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py index 0e29ae8c..863d1783 100644 --- a/timm/models/layers/conv2d_same.py +++ b/timm/models/layers/conv2d_same.py @@ -5,10 +5,10 @@ Hacked together by Ross Wightman import torch import torch.nn as nn import torch.nn.functional as F -from typing import Union, List, Tuple, Optional, Callable -import math +from typing import Tuple, Optional -from .padding import get_padding, pad_same, is_static_pad +from timm.models.layers.padding import get_padding_value +from .padding import pad_same def conv2d_same( @@ -31,29 +31,6 @@ class Conv2dSame(nn.Conv2d): return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) -def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: - dynamic = False - if isinstance(padding, str): - # for any string padding, the padding will be calculated for you, one of three ways - padding = padding.lower() - if padding == 'same': - # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact - if is_static_pad(kernel_size, **kwargs): - # static case, no extra overhead - padding = get_padding(kernel_size, **kwargs) - else: - # dynamic 'SAME' padding, has runtime/GPU memory overhead - padding = 0 - dynamic = True - elif padding == 'valid': - # 'VALID' padding, same as padding=0 - padding = 0 - else: - # Default to PyTorch style 'same'-ish symmetric padding - padding = get_padding(kernel_size, **kwargs) - return padding, dynamic - - def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): padding = kwargs.pop('padding', '') kwargs.setdefault('bias', False) diff --git a/timm/models/layers/padding.py b/timm/models/layers/padding.py index b3653866..0fca7cc6 100644 --- a/timm/models/layers/padding.py +++ b/timm/models/layers/padding.py @@ -3,7 +3,7 @@ Hacked together by Ross Wightman """ import math -from typing import List +from typing import List, Tuple import torch.nn.functional as F @@ -25,9 +25,32 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): # Dynamically pad input x with 'SAME' padding for conv with specified args -def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)): +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): ih, iw = x.size()[-2:] pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) return x + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py new file mode 100644 index 00000000..40f6dacc --- /dev/null +++ b/timm/models/layers/pool2d_same.py @@ -0,0 +1,71 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, List, Tuple, Optional +import math + +from .helpers import tup_pair +from .padding import pad_same, get_padding_value + + +def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + # FIXME how to deal with count_include_pad vs not for external padding? + x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = tup_pair(kernel_size) + stride = tup_pair(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + return avg_pool2d_same( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) + + +def max_pool2d_same( + x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + dilation: List[int] = (1, 1), ceil_mode: bool = False): + x = pad_same(x, kernel_size, stride, value=-float('inf')) + return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) + + +class MaxPool2dSame(nn.MaxPool2d): + """ Tensorflow like 'SAME' wrapper for 2D max pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): + kernel_size = tup_pair(kernel_size) + stride = tup_pair(stride) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) + + def forward(self, x): + return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) + + +def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop('padding', '') + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) + if is_dynamic: + if pool_type == 'avg': + return AvgPool2dSame(kernel_size, stride=stride, **kwargs) + elif pool_type == 'max': + return MaxPool2dSame(kernel_size, stride=stride, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' + else: + if pool_type == 'avg': + return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + elif pool_type == 'max': + return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}'