Fixup a few comments, add PyTorch version aware Flatten and finish as_sequential for GenEfficientNet

pull/53/head
Ross Wightman 5 years ago
parent 7ac6db4543
commit 35e8f0c5e7

@ -8,6 +8,7 @@ import numpy as np
import math
# Tuple helpers ripped from PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
@ -77,7 +78,7 @@ def get_padding_value(padding, kernel_size, **kwargs):
# static case, no extra overhead
padding = _get_padding(kernel_size, **kwargs)
else:
# dynamic padding
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
@ -101,6 +102,7 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
class MixedConv2d(nn.Module):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
@ -152,7 +154,11 @@ def get_condconv_initializer(initializer, num_experts, expert_shape):
class CondConv2d(nn.Module):
""" Conditional Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
@ -211,6 +217,7 @@ class CondConv2d(nn.Module):
if self._use_groups:
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
out = self.conv_fn(
x, weight, bias, stride=self.stride, padding=self.padding,

@ -2,6 +2,8 @@
A generic class with building blocks to support a variety of models with efficient architectures:
* EfficientNet (B0-B7)
* EfficientNet-EdgeTPU
* EfficientNet-CondConv
* MixNet (Small, Medium, and Large)
* MnasNet B1, A1 (SE), Small
* MobileNet V1, V2, and V3
@ -31,6 +33,7 @@ from .registry import register_model, model_entrypoint
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .conv2d_layers import select_conv2d
from .layers import Flatten
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -1050,16 +1053,14 @@ class GenEfficientNet(_GenEfficientNet):
layers = [self.conv_stem, self.bn1, self.act1]
layers.extend(self.blocks)
if self.head_conv == 'efficient':
layers.extend([self.global_pool, self.bn2, self.act2])
layers.extend([self.global_pool, self.conv_head, self.act2])
else:
layers.extend([self.conv_head, self.bn2, self.act2])
if self.global_pool is not None:
layers.append(self.global_pool)
#append flatten layer
layers.append(self.classifier)
layers.extend([Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
def get_classifier(self):
return self.classifier
@ -1106,7 +1107,8 @@ class GenEfficientNetFeatures(_GenEfficientNet):
#assert len(block_args) >= num_stages - 1
#block_args = block_args[:num_stages - 1]
super(GenEfficientNetFeatures, self).__init__( # FIXME it would be nice if Python made this nicer
# FIXME it would be nice if Python made this nicer without using kwargs and erasing IDE hints, etc
super(GenEfficientNetFeatures, self).__init__(
block_args, in_chans=in_chans, stem_size=stem_size,
output_stride=output_stride, pad_type=pad_type, act_layer=act_layer,
drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location,
@ -1548,6 +1550,11 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
""" Creates an EfficientNet-EdgeTPU model
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
"""
arch_def = [
# NOTE `fc` is present to override a mismatch between stem channels and in chs not
# present in other models
@ -1573,8 +1580,10 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
def _gen_efficientnet_condconv(
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
"""Creates an EfficientNet-CondConv model.
"""Creates an efficientnet-condconv model."""
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
"""
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25'],
['ir_r2_k3_s2_e6_c24_se0.25'],
@ -1584,6 +1593,8 @@ def _gen_efficientnet_condconv(
['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
]
# NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
# the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
model_kwargs = dict(
block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
num_features=_round_channels(1280, channel_multiplier, 8, None),
@ -2056,7 +2067,7 @@ def tf_efficientnet_el(pretrained=False, **kwargs):
@register_model
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
""" EfficientNet-B0 """
""" EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.2
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
@ -2068,7 +2079,7 @@ def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
@register_model
def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
""" EfficientNet-B0 """
""" EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.2
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
@ -2080,7 +2091,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
@register_model
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-B0 """
""" EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.2
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT

@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def versiontuple(v):
return tuple(map(int, (v.split("."))))[:3]
if versiontuple(torch.__version__) >= versiontuple('1.2.0'):
Flatten = nn.Flatten
else:
class Flatten(nn.Module):
r"""
Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
Args:
start_dim: first dim to flatten (default = 1).
end_dim: last dim to flatten (default = -1).
Shape:
- Input: :math:`(N, *dims)`
- Output: :math:`(N, \prod *dims)` (for the default case).
"""
__constants__ = ['start_dim', 'end_dim']
def __init__(self, start_dim=1, end_dim=-1):
super(Flatten, self).__init__()
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, input):
return input.flatten(self.start_dim, self.end_dim)
Loading…
Cancel
Save