Update EfficientNet feature extraction for EfficientDet. Add needed MaxPool2dSame as well.

pull/123/head
Ross Wightman 5 years ago
parent e01ccb88ce
commit 1a8f5900ab

@@ -326,7 +326,6 @@ class EfficientNet(nn.Module):
         # Stem
         if not fix_stem:
             stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        print(stem_size)
         self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
@@ -393,7 +392,7 @@ class EfficientNetFeatures(nn.Module):
    and object detection models.
    """

-    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
+    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
                  in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
@@ -404,6 +403,7 @@ class EfficientNetFeatures(nn.Module):

         num_stages = max(out_indices) + 1
         self.out_indices = out_indices
+        self.feature_location = feature_location
         self.drop_rate = drop_rate
         self._in_chs = in_chans
@@ -420,18 +420,23 @@ class EfficientNetFeatures(nn.Module):
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
-        self.feature_info = builder.features  # builder provides info about feature channels for each block
+        self._feature_info = builder.features  # builder provides info about feature channels for each block
+        self._stage_to_feature_idx = {
+            v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
         self._in_chs = builder.in_chs
         efficientnet_init_weights(self)
         if _DEBUG:
-            for k, v in self.feature_info.items():
+            for k, v in self._feature_info.items():
                 print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))

         # Register feature extraction hooks with FeatureHooks helper
-        hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
-        hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
-        self.feature_hooks = FeatureHooks(hooks, self.named_modules())
+        self.feature_hooks = None
+        if feature_location != 'bottleneck':
+            hooks = [dict(
+                name=self._feature_info[idx]['module'],
+                type=self._feature_info[idx]['hook_type']) for idx in out_indices]
+            self.feature_hooks = FeatureHooks(hooks, self.named_modules())
@@ -439,15 +444,32 @@ class EfficientNetFeatures(nn.Module):
     def feature_channels(self, idx=None):
         """ Feature Channel Shortcut
         Returns feature channel count for each output index if idx == None. If idx is an integer, will
         return feature channel count for that feature block index (independent of out_indices setting).
         """
         if isinstance(idx, int):
-            return self.feature_info[idx]['num_chs']
-        return [self.feature_info[i]['num_chs'] for i in self.out_indices]
+            return self._feature_info[idx]['num_chs']
+        return [self._feature_info[i]['num_chs'] for i in self.out_indices]
+
+    def feature_info(self, idx=None):
+        """ Feature Channel Shortcut
+        Returns feature channel count for each output index if idx == None. If idx is an integer, will
+        return feature channel count for that feature block index (independent of out_indices setting).
+        """
+        if isinstance(idx, int):
+            return self._feature_info[idx]
+        return [self._feature_info[i] for i in self.out_indices]

     def forward(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
         x = self.act1(x)
-        self.blocks(x)
-        return self.feature_hooks.get_output(x.device)
+        if self.feature_hooks is None:
+            features = []
+            for i, b in enumerate(self.blocks):
+                x = b(x)
+                if i in self._stage_to_feature_idx:
+                    features.append(x)
+            return features
+        else:
+            self.blocks(x)
+            return self.feature_hooks.get_output(x.device)


 def _create_model(model_kwargs, default_cfg, pretrained=False):
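For orientation, a rough usage sketch of the reworked EfficientNetFeatures (illustrative only: it assumes a model instance has already been built by an entrypoint that supplies decoded block_args, with the default out_indices=(0, 1, 2, 3, 4) and feature_location='bottleneck' in effect):

import torch

x = torch.randn(1, 3, 224, 224)
features = model(x)                  # with 'bottleneck', forward() returns a list of 5 feature maps
                                     # (C1..C5 at reductions 2, 4, 8, 16, 32); no hooks are used
channels = model.feature_channels()  # e.g. [16, 24, 40, 112, 320] for a B0-like config (assumed)
infos = model.feature_info()         # per-index dicts: module, hook_type, num_chs, stage_idx, reduction, ...

With 'expansion' or 'depthwise', the same forward call instead runs the blocks once and returns whatever the registered FeatureHooks captured.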

@@ -120,11 +120,13 @@ class ConvBnAct(nn.Module):
         self.bn1 = norm_layer(out_chs, **norm_kwargs)
         self.act1 = act_layer(inplace=True)

-    def feature_module(self, location):
-        return 'act1'
-
-    def feature_channels(self, location):
-        return self.conv.out_channels
+    def feature_info(self, location):
+        if location == 'expansion' or location == 'depthwise':
+            # no expansion or depthwise this block, use act after conv
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
+        return info

     def forward(self, x):
         x = self.conv(x)
@@ -165,12 +167,15 @@ class DepthwiseSeparableConv(nn.Module):
         self.bn2 = norm_layer(out_chs, **norm_kwargs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

-    def feature_module(self, location):
-        # no expansion in this block, pre pw only feature extraction point
-        return 'conv_pw'
-
-    def feature_channels(self, location):
-        return self.conv_pw.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            # no expansion in this block, use depthwise, before SE
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
+        elif location == 'depthwise':  # after SE
+            info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
+        return info

     def forward(self, x):
         residual = x
@@ -232,16 +237,14 @@ class InvertedResidual(nn.Module):
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
         self.bn3 = norm_layer(out_chs, **norm_kwargs)

-    def feature_module(self, location):
-        if location == 'post_exp':
-            return 'act1'
-        return 'conv_pwl'
-
-    def feature_channels(self, location):
-        if location == 'post_exp':
-            return self.conv_pw.out_channels
-        # location == 'pre_pw'
-        return self.conv_pwl.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            # act1 output has the expanded (mid_chs) channel count
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.out_channels)
+        elif location == 'depthwise':  # after SE
+            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+        return info

     def forward(self, x):
         residual = x
@@ -359,16 +362,15 @@ class EdgeResidual(nn.Module):
             mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
         self.bn2 = norm_layer(out_chs, **norm_kwargs)

-    def feature_module(self, location):
-        if location == 'post_exp':
-            return 'act1'
-        return 'conv_pwl'
-
-    def feature_channels(self, location):
-        if location == 'post_exp':
-            return self.conv_exp.out_channels
-        # location == 'pre_pw'
-        return self.conv_pwl.in_channels
+    def feature_info(self, location):
+        if location == 'expansion':
+            info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels)
+        elif location == 'depthwise':
+            # there is no depthwise, take after SE, before PWL
+            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+        else:  # location == 'bottleneck'
+            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+        return info

     def forward(self, x):
         residual = x
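All four block types now expose the same feature_info(location) contract, so the builder can treat them uniformly. A small illustration for an InvertedResidual instance (values are hedged and illustrative, not pulled from a real model):

block.feature_info('bottleneck')  # {'module': '', 'hook_type': '', 'num_chs': out_chs} -> use the block output directly
block.feature_info('depthwise')   # {'module': 'conv_pwl', 'hook_type': 'forward_pre', 'num_chs': mid_chs} -> input to the final PW-linear, after SE
block.feature_info('expansion')   # {'module': 'act1', 'hook_type': 'forward', 'num_chs': mid_chs} -> output of the expansion activation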

@@ -218,7 +218,7 @@ class EfficientNetBuilder:
         self.norm_kwargs = norm_kwargs
         self.drop_path_rate = drop_path_rate
         self.feature_location = feature_location
-        assert feature_location in ('pre_pwl', 'post_exp', '')
+        assert feature_location in ('bottleneck', 'depthwise', 'expansion', '')
         self.verbose = verbose

         # state updated during build, consumed by model
@@ -313,20 +313,21 @@ class EfficientNetBuilder:
                     block_args['stride'] = 1

                 do_extract = False
-                if self.feature_location == 'pre_pwl':
+                if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
                     if last_block:
                         next_stage_idx = stage_idx + 1
                         if next_stage_idx >= len(model_block_args):
                             do_extract = True
                         else:
                             do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
-                elif self.feature_location == 'post_exp':
-                    if block_args['stride'] > 1 or (last_stack and last_block) :
+                elif self.feature_location == 'expansion':
+                    if block_args['stride'] > 1 or (last_stack and last_block):
                         do_extract = True
                 if do_extract:
                     extract_features = self.feature_location

                 next_dilation = current_dilation
+                next_output_stride = current_stride
                 if block_args['stride'] > 1:
                     next_output_stride = current_stride * block_args['stride']
                     if next_output_stride > self.output_stride:
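Put differently: with 'bottleneck' or 'depthwise', a feature is taken from the last block of a stage whenever the next stage downsamples (or there is no next stage); with 'expansion', every striding block plus the last block of the last stack is tapped. A minimal sketch of the first rule, assuming EfficientNet-B0-like per-stage strides:

stage_strides = [1, 2, 2, 2, 1, 2, 1]  # assumed B0 stage strides, illustrative only
extract_stages = [
    i for i in range(len(stage_strides))
    if i + 1 >= len(stage_strides) or stage_strides[i + 1] > 1]
print(extract_stages)  # [0, 1, 2, 4, 6] -> the stages feeding C1..C5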
@@ -347,14 +348,13 @@ class EfficientNetBuilder:
                 # stash feature module name and channel info for model feature extraction
                 if extract_features:
-                    feature_module = block.feature_module(extract_features)
-                    if feature_module:
-                        feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module
-                    feature_channels = block.feature_channels(extract_features)
-                    self.features[feature_idx] = dict(
-                        name=feature_module,
-                        num_chs=feature_channels
-                    )
+                    feature_info = block.feature_info(extract_features)
+                    if feature_info['module']:
+                        feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module']
+                    feature_info['stage_idx'] = stage_idx
+                    feature_info['block_idx'] = block_idx
+                    feature_info['reduction'] = current_stride
+                    self.features[feature_idx] = feature_info
                     feature_idx += 1

                 total_block_idx += 1  # incr global block idx (across all stacks)
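Each self.features entry thus carries both the channel count and where/how to grab the tensor. A hypothetical 'bottleneck' entry for a B0-like build (values illustrative, not from a real run):

example_entry = dict(
    module='', hook_type='',   # empty -> take the block's output directly, no hook registered
    num_chs=40,                # C3 channels in a B0-like config
    stage_idx=2, block_idx=1,  # last block of the third stage
    reduction=8)               # network stride at this feature point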

@@ -1,9 +1,10 @@
 from .padding import get_padding
-from .avg_pool2d_same import AvgPool2dSame
+from .pool2d_same import AvgPool2dSame
 from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .mixed_conv2d import MixedConv2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
+from .pool2d_same import create_pool2d
 from .create_conv2d import create_conv2d
 from .create_attn import create_attn
 from .selective_kernel import SelectiveKernelConv

@@ -1,31 +0,0 @@
-""" AvgPool2d w/ Same Padding
-
-Hacked together by Ross Wightman
-"""
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import List
-import math
-
-from .helpers import tup_pair
-from .padding import pad_same
-
-
-def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
-                    ceil_mode: bool = False, count_include_pad: bool = True):
-    x = pad_same(x, kernel_size, stride)
-    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
-
-
-class AvgPool2dSame(nn.AvgPool2d):
-    """ Tensorflow like 'SAME' wrapper for 2D average pooling
-    """
-    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
-        kernel_size = tup_pair(kernel_size)
-        stride = tup_pair(stride)
-        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
-
-    def forward(self, x):
-        return avg_pool2d_same(
-            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)

@@ -14,7 +14,8 @@ from torch import nn as nn
 from torch.nn import functional as F

 from .helpers import tup_pair
-from .conv2d_same import get_padding_value, conv2d_same
+from .conv2d_same import conv2d_same
+from timm.models.layers.padding import get_padding_value


 def get_condconv_initializer(initializer, num_experts, expert_shape):

@@ -5,10 +5,10 @@ Hacked together by Ross Wightman
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Union, List, Tuple, Optional, Callable
-import math
+from typing import Tuple, Optional

-from .padding import get_padding, pad_same, is_static_pad
+from timm.models.layers.padding import get_padding_value
+from .padding import pad_same


 def conv2d_same(
@@ -31,29 +31,6 @@ class Conv2dSame(nn.Conv2d):
         return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


-def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
-    dynamic = False
-    if isinstance(padding, str):
-        # for any string padding, the padding will be calculated for you, one of three ways
-        padding = padding.lower()
-        if padding == 'same':
-            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
-            if is_static_pad(kernel_size, **kwargs):
-                # static case, no extra overhead
-                padding = get_padding(kernel_size, **kwargs)
-            else:
-                # dynamic 'SAME' padding, has runtime/GPU memory overhead
-                padding = 0
-                dynamic = True
-        elif padding == 'valid':
-            # 'VALID' padding, same as padding=0
-            padding = 0
-        else:
-            # Default to PyTorch style 'same'-ish symmetric padding
-            padding = get_padding(kernel_size, **kwargs)
-    return padding, dynamic
-
-
 def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
     padding = kwargs.pop('padding', '')
     kwargs.setdefault('bias', False)

@@ -3,7 +3,7 @@
 Hacked together by Ross Wightman
 """
 import math
-from typing import List
+from typing import List, Tuple

 import torch.nn.functional as F
@@ -25,9 +25,32 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):

 # Dynamically pad input x with 'SAME' padding for conv with specified args
-def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
     ih, iw = x.size()[-2:]
     pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
     if pad_h > 0 or pad_w > 0:
-        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
     return x


+def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+    dynamic = False
+    if isinstance(padding, str):
+        # for any string padding, the padding will be calculated for you, one of three ways
+        padding = padding.lower()
+        if padding == 'same':
+            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+            if is_static_pad(kernel_size, **kwargs):
+                # static case, no extra overhead
+                padding = get_padding(kernel_size, **kwargs)
+            else:
+                # dynamic 'SAME' padding, has runtime/GPU memory overhead
+                padding = 0
+                dynamic = True
+        elif padding == 'valid':
+            # 'VALID' padding, same as padding=0
+            padding = 0
+        else:
+            # Default to PyTorch style 'same'-ish symmetric padding
+            padding = get_padding(kernel_size, **kwargs)
+    return padding, dynamic
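get_padding_value only chooses between static symmetric padding and the dynamic pad_same path; the actual 'SAME' amount comes from get_same_padding (already present in this file, unchanged here), which works out to max((ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) per spatial dim. A quick worked check of the dynamic case, assuming that formula:

import math

def same_pad(i, k, s, d=1):
    # mirrors get_same_padding in this module (assumed, not changed by this commit)
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

pad_h = same_pad(224, k=3, s=2)               # -> 1 extra row for a 3x3, stride-2 pool or conv
top, bottom = pad_h // 2, pad_h - pad_h // 2  # -> (0, 1), the TF-style asymmetric split

The new value argument on pad_same exists so max pooling can pad with -inf, keeping padded positions from ever winning the max.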

@@ -0,0 +1,71 @@
+""" AvgPool2d w/ Same Padding
+
+Hacked together by Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Union, List, Tuple, Optional
+import math
+
+from .helpers import tup_pair
+from .padding import pad_same, get_padding_value
+
+
+def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+                    ceil_mode: bool = False, count_include_pad: bool = True):
+    # FIXME how to deal with count_include_pad vs not for external padding?
+    x = pad_same(x, kernel_size, stride)
+    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+
+class AvgPool2dSame(nn.AvgPool2d):
+    """ Tensorflow like 'SAME' wrapper for 2D average pooling
+    """
+    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
+        kernel_size = tup_pair(kernel_size)
+        stride = tup_pair(stride)
+        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+    def forward(self, x):
+        return avg_pool2d_same(
+            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+
+def max_pool2d_same(
+        x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+        dilation: List[int] = (1, 1), ceil_mode: bool = False):
+    x = pad_same(x, kernel_size, stride, value=-float('inf'))
+    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
+
+
+class MaxPool2dSame(nn.MaxPool2d):
+    """ Tensorflow like 'SAME' wrapper for 2D max pooling
+    """
+    def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
+        kernel_size = tup_pair(kernel_size)
+        stride = tup_pair(stride)
+        # nn.MaxPool2d takes (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+        super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, False, ceil_mode)
+
+    def forward(self, x):
+        return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
+
+
+def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
+    stride = stride or kernel_size
+    padding = kwargs.pop('padding', '')
+    padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
+    if is_dynamic:
+        if pool_type == 'avg':
+            return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
+        elif pool_type == 'max':
+            return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
+        else:
+            assert False, f'Unsupported pool type {pool_type}'
+    else:
+        if pool_type == 'avg':
+            return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+        elif pool_type == 'max':
+            return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+        else:
+            assert False, f'Unsupported pool type {pool_type}'
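A short usage sketch of create_pool2d (shapes assume a 224x224 input; the import relies on the __init__.py export added above):

import torch
from timm.models.layers import create_pool2d

x = torch.randn(1, 32, 224, 224)

# stride 2 cannot be padded statically, so padding='same' falls back to the dynamic MaxPool2dSame
pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
print(type(pool).__name__, tuple(pool(x).shape))  # MaxPool2dSame (1, 32, 112, 112)

# stride 1 resolves to a static symmetric pad and a plain nn.MaxPool2d
pool = create_pool2d('max', kernel_size=3, stride=1, padding='same')
print(type(pool).__name__, tuple(pool(x).shape))  # MaxPool2d (1, 32, 224, 224)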