Merge pull request #53 from rwightman/condconvs_and_features
Major model merge (EfficientNet-CondConv, EfficientNet-AdvProp, TF MobileNetV3, HRNet, more)pull/62/head v0.1-hrnet
commit
3ceeedc441
@ -1,3 +1,3 @@
|
||||
torch>=1.1.0
|
||||
torchvision>=0.3.0
|
||||
torch>=1.2.0
|
||||
torchvision>=0.4.0
|
||||
pyyaml
|
||||
|
@ -0,0 +1,155 @@
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
_USE_MEM_EFFICIENT_ISH = True
|
||||
if _USE_MEM_EFFICIENT_ISH:
|
||||
# This version reduces memory overhead of Swish during training by
|
||||
# recomputing torch.sigmoid(x) in backward instead of saving it.
|
||||
@torch.jit.script
|
||||
def swish_jit_fwd(x):
|
||||
return x.mul(torch.sigmoid(x))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit_bwd(x, grad_output):
|
||||
x_sigmoid = torch.sigmoid(x)
|
||||
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
|
||||
|
||||
|
||||
class SwishJitAutoFn(torch.autograd.Function):
|
||||
""" torch.jit.script optimised Swish
|
||||
Inspired by conversation btw Jeremy Howard & Adam Pazske
|
||||
https://twitter.com/jeremyphoward/status/1188251041835315200
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return swish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return swish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def swish(x, _inplace=False):
|
||||
return SwishJitAutoFn.apply(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit_fwd(x):
|
||||
return x.mul(torch.tanh(F.softplus(x)))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def mish_jit_bwd(x, grad_output):
|
||||
x_sigmoid = torch.sigmoid(x)
|
||||
x_tanh_sp = F.softplus(x).tanh()
|
||||
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
|
||||
|
||||
|
||||
class MishJitAutoFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return mish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
return mish_jit_bwd(x, grad_output)
|
||||
|
||||
def mish(x, _inplace=False):
|
||||
return MishJitAutoFn.apply(x)
|
||||
|
||||
else:
|
||||
def swish(x, inplace=False):
|
||||
"""Swish - Described in: https://arxiv.org/abs/1710.05941
|
||||
"""
|
||||
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
|
||||
|
||||
|
||||
def mish(x, _inplace=False):
|
||||
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||
"""
|
||||
return x.mul(F.softplus(x).tanh())
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(Swish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return swish(x, self.inplace)
|
||||
|
||||
|
||||
class Mish(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(Mish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return mish(x, self.inplace)
|
||||
|
||||
|
||||
def sigmoid(x, inplace=False):
|
||||
return x.sigmoid_() if inplace else x.sigmoid()
|
||||
|
||||
|
||||
# PyTorch has this, but not with a consistent inplace argmument interface
|
||||
class Sigmoid(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(Sigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x.sigmoid_() if self.inplace else x.sigmoid()
|
||||
|
||||
|
||||
def tanh(x, inplace=False):
|
||||
return x.tanh_() if inplace else x.tanh()
|
||||
|
||||
|
||||
# PyTorch has this, but not with a consistent inplace argmument interface
|
||||
class Tanh(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(Tanh, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return x.tanh_() if self.inplace else x.tanh()
|
||||
|
||||
|
||||
def hard_swish(x, inplace=False):
|
||||
inner = F.relu6(x + 3.).div_(6.)
|
||||
return x.mul_(inner) if inplace else x.mul(inner)
|
||||
|
||||
|
||||
class HardSwish(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(HardSwish, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return hard_swish(x, self.inplace)
|
||||
|
||||
|
||||
def hard_sigmoid(x, inplace=False):
|
||||
if inplace:
|
||||
return x.add_(3.).clamp_(0., 6.).div_(6.)
|
||||
else:
|
||||
return F.relu6(x + 3.) / 6.
|
||||
|
||||
|
||||
class HardSigmoid(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(HardSigmoid, self).__init__()
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x):
|
||||
return hard_sigmoid(x, self.inplace)
|
||||
|
@ -1,120 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
|
||||
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
||||
|
||||
|
||||
def _get_padding(kernel_size, stride=1, dilation=1, **_):
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
return padding
|
||||
|
||||
|
||||
def _calc_same_pad(i, k, s, d):
|
||||
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
||||
|
||||
|
||||
def _split_channels(num_chan, num_groups):
|
||||
split = [num_chan // num_groups for _ in range(num_groups)]
|
||||
split[0] += num_chan - sum(split)
|
||||
return split
|
||||
|
||||
|
||||
class Conv2dSame(nn.Conv2d):
|
||||
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2dSame, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation,
|
||||
groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
ih, iw = x.size()[-2:]
|
||||
kh, kw = self.weight.size()[-2:]
|
||||
pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0])
|
||||
pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1])
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
|
||||
return F.conv2d(x, self.weight, self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
||||
padding = kwargs.pop('padding', '')
|
||||
kwargs.setdefault('bias', False)
|
||||
if isinstance(padding, str):
|
||||
# for any string padding, the padding will be calculated for you, one of three ways
|
||||
padding = padding.lower()
|
||||
if padding == 'same':
|
||||
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
||||
if _is_static_pad(kernel_size, **kwargs):
|
||||
# static case, no extra overhead
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||
else:
|
||||
# dynamic padding
|
||||
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||
elif padding == 'valid':
|
||||
# 'VALID' padding, same as padding=0
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
|
||||
else:
|
||||
# Default to PyTorch style 'same'-ish symmetric padding
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||
else:
|
||||
# padding was specified as a number or pair
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||
|
||||
|
||||
class MixedConv2d(nn.Module):
|
||||
""" Mixed Grouped Convolution
|
||||
Based on MDConv and GroupedConv in MixNet impl:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||
stride=1, padding='', dilated=False, depthwise=False, **kwargs):
|
||||
super(MixedConv2d, self).__init__()
|
||||
|
||||
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
|
||||
num_groups = len(kernel_size)
|
||||
in_splits = _split_channels(in_channels, num_groups)
|
||||
out_splits = _split_channels(out_channels, num_groups)
|
||||
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
||||
d = 1
|
||||
# FIXME make compat with non-square kernel/dilations/strides
|
||||
if stride == 1 and dilated:
|
||||
d, k = (k - 1) // 2, 3
|
||||
conv_groups = out_ch if depthwise else 1
|
||||
# use add_module to keep key space clean
|
||||
self.add_module(
|
||||
str(idx),
|
||||
conv2d_pad(
|
||||
in_ch, out_ch, k, stride=stride,
|
||||
padding=padding, dilation=d, groups=conv_groups, **kwargs)
|
||||
)
|
||||
self.splits = in_splits
|
||||
|
||||
def forward(self, x):
|
||||
x_split = torch.split(x, self.splits, 1)
|
||||
x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
|
||||
x = torch.cat(x_out, 1)
|
||||
return x
|
||||
|
||||
|
||||
# helper method
|
||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
||||
if isinstance(kernel_size, list):
|
||||
# We're going to use only lists for defining the MixedConv2d kernel groups,
|
||||
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
|
||||
return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
depthwise = kwargs.pop('depthwise', False)
|
||||
groups = out_chs if depthwise else 1
|
||||
return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
|
@ -0,0 +1,260 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch._six import container_abcs
|
||||
from itertools import repeat
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
# Tuple helpers ripped from PyTorch
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, container_abcs.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
return parse
|
||||
|
||||
|
||||
_single = _ntuple(1)
|
||||
_pair = _ntuple(2)
|
||||
_triple = _ntuple(3)
|
||||
_quadruple = _ntuple(4)
|
||||
|
||||
|
||||
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
|
||||
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
||||
|
||||
|
||||
def _get_padding(kernel_size, stride=1, dilation=1, **_):
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
return padding
|
||||
|
||||
|
||||
def _calc_same_pad(i, k, s, d):
|
||||
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
||||
|
||||
|
||||
def _split_channels(num_chan, num_groups):
|
||||
split = [num_chan // num_groups for _ in range(num_groups)]
|
||||
split[0] += num_chan - sum(split)
|
||||
return split
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1):
|
||||
ih, iw = x.size()[-2:]
|
||||
kh, kw = weight.size()[-2:]
|
||||
pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
|
||||
pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
||||
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
||||
|
||||
|
||||
class Conv2dSame(nn.Conv2d):
|
||||
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2dSame, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
def get_padding_value(padding, kernel_size, **kwargs):
|
||||
dynamic = False
|
||||
if isinstance(padding, str):
|
||||
# for any string padding, the padding will be calculated for you, one of three ways
|
||||
padding = padding.lower()
|
||||
if padding == 'same':
|
||||
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
||||
if _is_static_pad(kernel_size, **kwargs):
|
||||
# static case, no extra overhead
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
else:
|
||||
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
||||
padding = 0
|
||||
dynamic = True
|
||||
elif padding == 'valid':
|
||||
# 'VALID' padding, same as padding=0
|
||||
padding = 0
|
||||
else:
|
||||
# Default to PyTorch style 'same'-ish symmetric padding
|
||||
padding = _get_padding(kernel_size, **kwargs)
|
||||
return padding, dynamic
|
||||
|
||||
|
||||
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
||||
padding = kwargs.pop('padding', '')
|
||||
kwargs.setdefault('bias', False)
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
||||
if is_dynamic:
|
||||
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||
|
||||
|
||||
class MixedConv2d(nn.Module):
|
||||
""" Mixed Grouped Convolution
|
||||
Based on MDConv and GroupedConv in MixNet impl:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
||||
|
||||
NOTE: This does not currently work with torch.jit.script
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
|
||||
super(MixedConv2d, self).__init__()
|
||||
|
||||
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
|
||||
num_groups = len(kernel_size)
|
||||
in_splits = _split_channels(in_channels, num_groups)
|
||||
out_splits = _split_channels(out_channels, num_groups)
|
||||
self.in_channels = sum(in_splits)
|
||||
self.out_channels = sum(out_splits)
|
||||
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
|
||||
conv_groups = out_ch if depthwise else 1
|
||||
# use add_module to keep key space clean
|
||||
self.add_module(
|
||||
str(idx),
|
||||
create_conv2d_pad(
|
||||
in_ch, out_ch, k, stride=stride,
|
||||
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
|
||||
)
|
||||
self.splits = in_splits
|
||||
|
||||
def forward(self, x):
|
||||
x_split = torch.split(x, self.splits, 1)
|
||||
x_out = [c(x) for x, c in zip(x_split, self._modules.values())]
|
||||
x = torch.cat(x_out, 1)
|
||||
return x
|
||||
|
||||
|
||||
def get_condconv_initializer(initializer, num_experts, expert_shape):
|
||||
def condconv_initializer(weight):
|
||||
"""CondConv initializer function."""
|
||||
num_params = np.prod(expert_shape)
|
||||
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
|
||||
weight.shape[1] != num_params):
|
||||
raise (ValueError(
|
||||
'CondConv variables must have shape [num_experts, num_params]'))
|
||||
for i in range(num_experts):
|
||||
initializer(weight[i].view(expert_shape))
|
||||
return condconv_initializer
|
||||
|
||||
|
||||
class CondConv2d(nn.Module):
|
||||
""" Conditional Convolution
|
||||
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
||||
|
||||
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
||||
https://github.com/pytorch/pytorch/issues/17983
|
||||
"""
|
||||
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
|
||||
super(CondConv2d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = _pair(kernel_size)
|
||||
self.stride = _pair(stride)
|
||||
padding_val, is_padding_dynamic = get_padding_value(
|
||||
padding, kernel_size, stride=stride, dilation=dilation)
|
||||
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
|
||||
self.padding = _pair(padding_val)
|
||||
self.dilation = _pair(dilation)
|
||||
self.groups = groups
|
||||
self.num_experts = num_experts
|
||||
|
||||
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||
weight_num_param = 1
|
||||
for wd in self.weight_shape:
|
||||
weight_num_param *= wd
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
|
||||
|
||||
if bias:
|
||||
self.bias_shape = (self.out_channels,)
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init_weight = get_condconv_initializer(
|
||||
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
|
||||
init_weight(self.weight)
|
||||
if self.bias is not None:
|
||||
fan_in = np.prod(self.weight_shape[1:])
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
init_bias = get_condconv_initializer(
|
||||
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
|
||||
init_bias(self.bias)
|
||||
|
||||
def forward(self, x, routing_weights):
|
||||
B, C, H, W = x.shape
|
||||
weight = torch.matmul(routing_weights, self.weight)
|
||||
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||
weight = weight.view(new_weight_shape)
|
||||
bias = None
|
||||
if self.bias is not None:
|
||||
bias = torch.matmul(routing_weights, self.bias)
|
||||
bias = bias.view(B * self.out_channels)
|
||||
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
|
||||
x = x.view(1, B * C, H, W)
|
||||
if self.dynamic_padding:
|
||||
out = conv2d_same(
|
||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups * B)
|
||||
else:
|
||||
out = F.conv2d(
|
||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups * B)
|
||||
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
|
||||
|
||||
# Literal port (from TF definition)
|
||||
# x = torch.split(x, 1, 0)
|
||||
# weight = torch.split(weight, 1, 0)
|
||||
# if self.bias is not None:
|
||||
# bias = torch.matmul(routing_weights, self.bias)
|
||||
# bias = torch.split(bias, 1, 0)
|
||||
# else:
|
||||
# bias = [None] * B
|
||||
# out = []
|
||||
# for xi, wi, bi in zip(x, weight, bias):
|
||||
# wi = wi.view(*self.weight_shape)
|
||||
# if bi is not None:
|
||||
# bi = bi.view(*self.bias_shape)
|
||||
# out.append(self.conv_fn(
|
||||
# xi, wi, bi, stride=self.stride, padding=self.padding,
|
||||
# dilation=self.dilation, groups=self.groups))
|
||||
# out = torch.cat(out, 0)
|
||||
return out
|
||||
|
||||
|
||||
# helper method
|
||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||
assert 'groups' not in kwargs # only use 'depthwise' bool arg
|
||||
if isinstance(kernel_size, list):
|
||||
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
|
||||
# We're going to use only lists for defining the MixedConv2d kernel groups,
|
||||
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
|
||||
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
depthwise = kwargs.pop('depthwise', False)
|
||||
groups = out_chs if depthwise else 1
|
||||
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
|
||||
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
else:
|
||||
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
return m
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,404 @@
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .activations import sigmoid
|
||||
from .conv2d_layers import *
|
||||
|
||||
|
||||
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
|
||||
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
|
||||
# NOTE: momentum varies btw .99 and .9997 depending on source
|
||||
# .99 in official TF TPU impl
|
||||
# .9997 (/w .999 in search space) for paper
|
||||
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
||||
BN_EPS_TF_DEFAULT = 1e-3
|
||||
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
|
||||
|
||||
|
||||
def get_bn_args_tf():
|
||||
return _BN_ARGS_TF.copy()
|
||||
|
||||
|
||||
def resolve_bn_args(kwargs):
|
||||
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
|
||||
bn_momentum = kwargs.pop('bn_momentum', None)
|
||||
if bn_momentum is not None:
|
||||
bn_args['momentum'] = bn_momentum
|
||||
bn_eps = kwargs.pop('bn_eps', None)
|
||||
if bn_eps is not None:
|
||||
bn_args['eps'] = bn_eps
|
||||
return bn_args
|
||||
|
||||
|
||||
_SE_ARGS_DEFAULT = dict(
|
||||
gate_fn=sigmoid,
|
||||
act_layer=None,
|
||||
reduce_mid=False,
|
||||
divisor=1)
|
||||
|
||||
|
||||
def resolve_se_args(kwargs, in_chs, act_layer=None):
|
||||
se_kwargs = kwargs.copy() if kwargs is not None else {}
|
||||
# fill in args that aren't specified with the defaults
|
||||
for k, v in _SE_ARGS_DEFAULT.items():
|
||||
se_kwargs.setdefault(k, v)
|
||||
# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
|
||||
if not se_kwargs.pop('reduce_mid'):
|
||||
se_kwargs['reduced_base_chs'] = in_chs
|
||||
# act_layer override, if it remains None, the containing block's act_layer will be used
|
||||
if se_kwargs['act_layer'] is None:
|
||||
assert act_layer is not None
|
||||
se_kwargs['act_layer'] = act_layer
|
||||
return se_kwargs
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
min_value = min_value or divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
if not multiplier:
|
||||
return channels
|
||||
channels *= multiplier
|
||||
return make_divisible(channels, divisor, channel_min)
|
||||
|
||||
|
||||
def drop_connect(inputs, training=False, drop_connect_rate=0.):
|
||||
"""Apply drop connect."""
|
||||
if not training:
|
||||
return inputs
|
||||
|
||||
keep_prob = 1 - drop_connect_rate
|
||||
random_tensor = keep_prob + torch.rand(
|
||||
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = inputs.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class ChannelShuffle(nn.Module):
|
||||
# FIXME haven't used yet
|
||||
def __init__(self, groups):
|
||||
super(ChannelShuffle, self).__init__()
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
"""Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
|
||||
N, C, H, W = x.size()
|
||||
g = self.groups
|
||||
assert C % g == 0, "Incompatible group size {} for input channel {}".format(
|
||||
g, C
|
||||
)
|
||||
return (
|
||||
x.view(N, g, int(C / g), H, W)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
.contiguous()
|
||||
.view(N, C, H, W)
|
||||
)
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Module):
|
||||
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
|
||||
act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
self.gate_fn = gate_fn
|
||||
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x_se = self.avg_pool(x)
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act1(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
x = x * self.gate_fn(x_se)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBnAct(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, kernel_size,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
super(ConvBnAct, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
|
||||
self.bn1 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
def feature_module(self, location):
|
||||
return 'act1'
|
||||
|
||||
def feature_channels(self, location):
|
||||
return self.conv.out_channels
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
return x
|
||||
|
||||
|
||||
class DepthwiseSeparableConv(nn.Module):
|
||||
""" DepthwiseSeparable block
|
||||
Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
|
||||
(factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
|
||||
"""
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.has_pw_act = pw_act # activation after point-wise conv
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
self.conv_dw = select_conv2d(
|
||||
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
|
||||
self.bn1 = norm_layer(in_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if self.has_se:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
|
||||
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
|
||||
self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
|
||||
|
||||
def feature_module(self, location):
|
||||
# no expansion in this block, pre pw only feature extraction point
|
||||
return 'conv_pw'
|
||||
|
||||
def feature_channels(self, location):
|
||||
return self.conv_pw.in_channels
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
""" Inverted residual block w/ optional SE and CondConv routing"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
conv_kwargs=None, drop_connect_rate=0.):
|
||||
super(InvertedResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
conv_kwargs = conv_kwargs or {}
|
||||
mid_chs = make_divisible(in_chs * exp_ratio)
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# Point-wise expansion
|
||||
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Depth-wise convolution
|
||||
self.conv_dw = select_conv2d(
|
||||
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
|
||||
padding=pad_type, depthwise=True, **conv_kwargs)
|
||||
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if self.has_se:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
self.bn3 = norm_layer(out_chs, **norm_kwargs)
|
||||
|
||||
def feature_module(self, location):
|
||||
if location == 'post_exp':
|
||||
return 'act1'
|
||||
return 'conv_pwl'
|
||||
|
||||
def feature_channels(self, location):
|
||||
if location == 'post_exp':
|
||||
return self.conv_pw.out_channels
|
||||
# location == 'pre_pw'
|
||||
return self.conv_pwl.in_channels
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# Point-wise expansion
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Depth-wise convolution
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CondConvResidual(InvertedResidual):
|
||||
""" Inverted residual block w/ CondConv routing"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
|
||||
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
num_experts=0, drop_connect_rate=0.):
|
||||
|
||||
self.num_experts = num_experts
|
||||
conv_kwargs = dict(num_experts=self.num_experts)
|
||||
|
||||
super(CondConvResidual, self).__init__(
|
||||
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
|
||||
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
|
||||
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
|
||||
drop_connect_rate=drop_connect_rate)
|
||||
|
||||
self.routing_fn = nn.Linear(in_chs, self.num_experts)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# CondConv routing
|
||||
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
|
||||
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
|
||||
|
||||
# Point-wise expansion
|
||||
x = self.conv_pw(x, routing_weights)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Depth-wise convolution
|
||||
x = self.conv_dw(x, routing_weights)
|
||||
x = self.bn2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x, routing_weights)
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class EdgeResidual(nn.Module):
|
||||
""" Residual block with expansion convolution followed by pointwise-linear w/ stride"""
|
||||
|
||||
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
|
||||
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
|
||||
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
drop_connect_rate=0.):
|
||||
super(EdgeResidual, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
if fake_in_chs > 0:
|
||||
mid_chs = make_divisible(fake_in_chs * exp_ratio)
|
||||
else:
|
||||
mid_chs = make_divisible(in_chs * exp_ratio)
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
|
||||
# Expansion convolution
|
||||
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
|
||||
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if self.has_se:
|
||||
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = select_conv2d(
|
||||
mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
|
||||
self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
||||
|
||||
def feature_module(self, location):
|
||||
if location == 'post_exp':
|
||||
return 'act1'
|
||||
return 'conv_pwl'
|
||||
|
||||
def feature_channels(self, location):
|
||||
if location == 'post_exp':
|
||||
return self.conv_exp.out_channels
|
||||
# location == 'pre_pw'
|
||||
return self.conv_pwl.in_channels
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# Expansion convolution
|
||||
x = self.conv_exp(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if self.has_se:
|
||||
x = self.se(x)
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn2(x)
|
||||
|
||||
if self.has_residual:
|
||||
if self.drop_connect_rate > 0.:
|
||||
x = drop_connect(x, self.training, self.drop_connect_rate)
|
||||
x += residual
|
||||
|
||||
return x
|
@ -0,0 +1,402 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections.__init__ import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import torch.nn as nn
|
||||
from .activations import sigmoid, HardSwish, Swish
|
||||
from .efficientnet_blocks import *
|
||||
|
||||
|
||||
def _parse_ksize(ss):
|
||||
if ss.isdigit():
|
||||
return int(ss)
|
||||
else:
|
||||
return [int(k) for k in ss.split('.')]
|
||||
|
||||
|
||||
def _decode_block_str(block_str):
|
||||
""" Decode block definition string
|
||||
|
||||
Gets a list of block arg (dicts) through a string notation of arguments.
|
||||
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
||||
|
||||
All args can exist in any order with the exception of the leading string which
|
||||
is assumed to indicate the block type.
|
||||
|
||||
leading string - block type (
|
||||
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
|
||||
r - number of repeat blocks,
|
||||
k - kernel size,
|
||||
s - strides (1-9),
|
||||
e - expansion ratio,
|
||||
c - output channels,
|
||||
se - squeeze/excitation ratio
|
||||
n - activation fn ('re', 'r6', 'hs', or 'sw')
|
||||
Args:
|
||||
block_str: a string representation of block arguments.
|
||||
Returns:
|
||||
A list of block args (dicts)
|
||||
Raises:
|
||||
ValueError: if the string def not properly specified (TODO)
|
||||
"""
|
||||
assert isinstance(block_str, str)
|
||||
ops = block_str.split('_')
|
||||
block_type = ops[0] # take the block type off the front
|
||||
ops = ops[1:]
|
||||
options = {}
|
||||
noskip = False
|
||||
for op in ops:
|
||||
# string options being checked on individual basis, combine if they grow
|
||||
if op == 'noskip':
|
||||
noskip = True
|
||||
elif op.startswith('n'):
|
||||
# activation fn
|
||||
key = op[0]
|
||||
v = op[1:]
|
||||
if v == 're':
|
||||
value = nn.ReLU
|
||||
elif v == 'r6':
|
||||
value = nn.ReLU6
|
||||
elif v == 'hs':
|
||||
value = HardSwish
|
||||
elif v == 'sw':
|
||||
value = Swish
|
||||
else:
|
||||
continue
|
||||
options[key] = value
|
||||
else:
|
||||
# all numeric options
|
||||
splits = re.split(r'(\d.*)', op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
# if act_layer is None, the model default (passed to model init) will be used
|
||||
act_layer = options['n'] if 'n' in options else None
|
||||
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
||||
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
||||
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
|
||||
|
||||
num_repeat = int(options['r'])
|
||||
# each type of block has different valid arguments, fill accordingly
|
||||
if block_type == 'ir':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
exp_kernel_size=exp_kernel_size,
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
)
|
||||
if 'cc' in options:
|
||||
block_args['num_experts'] = int(options['cc'])
|
||||
elif block_type == 'ds' or block_type == 'dsa':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
dw_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
pw_act=block_type == 'dsa',
|
||||
noskip=block_type == 'dsa' or noskip,
|
||||
)
|
||||
elif block_type == 'er':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
exp_kernel_size=_parse_ksize(options['k']),
|
||||
pw_kernel_size=pw_kernel_size,
|
||||
out_chs=int(options['c']),
|
||||
exp_ratio=float(options['e']),
|
||||
fake_in_chs=fake_in_chs,
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
noskip=noskip,
|
||||
)
|
||||
elif block_type == 'cn':
|
||||
block_args = dict(
|
||||
block_type=block_type,
|
||||
kernel_size=int(options['k']),
|
||||
out_chs=int(options['c']),
|
||||
stride=int(options['s']),
|
||||
act_layer=act_layer,
|
||||
)
|
||||
else:
|
||||
assert False, 'Unknown block type (%s)' % block_type
|
||||
|
||||
return block_args, num_repeat
|
||||
|
||||
|
||||
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
""" Per-stage depth scaling
|
||||
Scales the block repeats in each stage. This depth scaling impl maintains
|
||||
compatibility with the EfficientNet scaling method, while allowing sensible
|
||||
scaling for other models that may have multiple block arg definitions in each stage.
|
||||
"""
|
||||
|
||||
# We scale the total repeat count for each stage, there may be multiple
|
||||
# block arg defs per stage so we need to sum.
|
||||
num_repeat = sum(repeats)
|
||||
if depth_trunc == 'round':
|
||||
# Truncating to int by rounding allows stages with few repeats to remain
|
||||
# proportionally smaller for longer. This is a good choice when stage definitions
|
||||
# include single repeat stages that we'd prefer to keep that way as long as possible
|
||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
||||
else:
|
||||
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
||||
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
||||
|
||||
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
||||
# The first block makes less sense to repeat in most of the arch definitions.
|
||||
repeats_scaled = []
|
||||
for r in repeats[::-1]:
|
||||
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
||||
repeats_scaled.append(rs)
|
||||
num_repeat -= r
|
||||
num_repeat_scaled -= rs
|
||||
repeats_scaled = repeats_scaled[::-1]
|
||||
|
||||
# Apply the calculated scaling to each block arg in the stage
|
||||
sa_scaled = []
|
||||
for ba, rep in zip(stack_args, repeats_scaled):
|
||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1):
|
||||
arch_args = []
|
||||
for stack_idx, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
stack_args = []
|
||||
repeats = []
|
||||
for block_str in block_strings:
|
||||
assert isinstance(block_str, str)
|
||||
ba, rep = _decode_block_str(block_str)
|
||||
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
|
||||
ba['num_experts'] *= experts_multiplier
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
class EfficientNetBuilder:
|
||||
""" Build Trunk Blocks
|
||||
|
||||
This ended up being somewhat of a cross between
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
|
||||
and
|
||||
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
|
||||
|
||||
"""
|
||||
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0., feature_location='',
|
||||
verbose=False):
|
||||
self.channel_multiplier = channel_multiplier
|
||||
self.channel_divisor = channel_divisor
|
||||
self.channel_min = channel_min
|
||||
self.output_stride = output_stride
|
||||
self.pad_type = pad_type
|
||||
self.act_layer = act_layer
|
||||
self.se_kwargs = se_kwargs
|
||||
self.norm_layer = norm_layer
|
||||
self.norm_kwargs = norm_kwargs
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
self.feature_location = feature_location
|
||||
assert feature_location in ('pre_pwl', 'post_exp', '')
|
||||
self.verbose = verbose
|
||||
|
||||
# state updated during build, consumed by model
|
||||
self.in_chs = None
|
||||
self.features = OrderedDict()
|
||||
|
||||
def _round_channels(self, chs):
|
||||
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
|
||||
|
||||
def _make_block(self, ba, block_idx, block_count):
|
||||
drop_connect_rate = self.drop_connect_rate * block_idx / block_count
|
||||
bt = ba.pop('block_type')
|
||||
ba['in_chs'] = self.in_chs
|
||||
ba['out_chs'] = self._round_channels(ba['out_chs'])
|
||||
if 'fake_in_chs' in ba and ba['fake_in_chs']:
|
||||
# FIXME this is a hack to work around mismatch in origin impl input filters
|
||||
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
|
||||
ba['norm_layer'] = self.norm_layer
|
||||
ba['norm_kwargs'] = self.norm_kwargs
|
||||
ba['pad_type'] = self.pad_type
|
||||
# block act fn overrides the model default
|
||||
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
|
||||
assert ba['act_layer'] is not None
|
||||
if bt == 'ir':
|
||||
ba['drop_connect_rate'] = drop_connect_rate
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if self.verbose:
|
||||
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
|
||||
if ba.get('num_experts', 0) > 0:
|
||||
block = CondConvResidual(**ba)
|
||||
else:
|
||||
block = InvertedResidual(**ba)
|
||||
elif bt == 'ds' or bt == 'dsa':
|
||||
ba['drop_connect_rate'] = drop_connect_rate
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if self.verbose:
|
||||
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
|
||||
block = DepthwiseSeparableConv(**ba)
|
||||
elif bt == 'er':
|
||||
ba['drop_connect_rate'] = drop_connect_rate
|
||||
ba['se_kwargs'] = self.se_kwargs
|
||||
if self.verbose:
|
||||
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
|
||||
block = EdgeResidual(**ba)
|
||||
elif bt == 'cn':
|
||||
if self.verbose:
|
||||
logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
|
||||
block = ConvBnAct(**ba)
|
||||
else:
|
||||
assert False, 'Uknkown block type (%s) while building model.' % bt
|
||||
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
|
||||
|
||||
return block
|
||||
|
||||
def __call__(self, in_chs, model_block_args):
|
||||
""" Build the blocks
|
||||
Args:
|
||||
in_chs: Number of input-channels passed to first block
|
||||
model_block_args: A list of lists, outer list defines stages, inner
|
||||
list contains strings defining block configuration(s)
|
||||
Return:
|
||||
List of block stacks (each stack wrapped in nn.Sequential)
|
||||
"""
|
||||
if self.verbose:
|
||||
logging.info('Building model trunk with %d stages...' % len(model_block_args))
|
||||
self.in_chs = in_chs
|
||||
total_block_count = sum([len(x) for x in model_block_args])
|
||||
total_block_idx = 0
|
||||
current_stride = 2
|
||||
current_dilation = 1
|
||||
feature_idx = 0
|
||||
stages = []
|
||||
# outer list of block_args defines the stacks ('stages' by some conventions)
|
||||
for stage_idx, stage_block_args in enumerate(model_block_args):
|
||||
last_stack = stage_idx == (len(model_block_args) - 1)
|
||||
if self.verbose:
|
||||
logging.info('Stack: {}'.format(stage_idx))
|
||||
assert isinstance(stage_block_args, list)
|
||||
|
||||
blocks = []
|
||||
# each stack (stage) contains a list of block arguments
|
||||
for block_idx, block_args in enumerate(stage_block_args):
|
||||
last_block = block_idx == (len(stage_block_args) - 1)
|
||||
extract_features = '' # No features extracted
|
||||
if self.verbose:
|
||||
logging.info(' Block: {}'.format(block_idx))
|
||||
|
||||
# Sort out stride, dilation, and feature extraction details
|
||||
assert block_args['stride'] in (1, 2)
|
||||
if block_idx >= 1:
|
||||
# only the first block in any stack can have a stride > 1
|
||||
block_args['stride'] = 1
|
||||
|
||||
do_extract = False
|
||||
if self.feature_location == 'pre_pwl':
|
||||
if last_block:
|
||||
next_stage_idx = stage_idx + 1
|
||||
if next_stage_idx >= len(model_block_args):
|
||||
do_extract = True
|
||||
else:
|
||||
do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
|
||||
elif self.feature_location == 'post_exp':
|
||||
if block_args['stride'] > 1 or (last_stack and last_block) :
|
||||
do_extract = True
|
||||
if do_extract:
|
||||
extract_features = self.feature_location
|
||||
|
||||
next_dilation = current_dilation
|
||||
if block_args['stride'] > 1:
|
||||
next_output_stride = current_stride * block_args['stride']
|
||||
if next_output_stride > self.output_stride:
|
||||
next_dilation = current_dilation * block_args['stride']
|
||||
block_args['stride'] = 1
|
||||
if self.verbose:
|
||||
logging.info(' Converting stride to dilation to maintain output_stride=={}'.format(
|
||||
self.output_stride))
|
||||
else:
|
||||
current_stride = next_output_stride
|
||||
block_args['dilation'] = current_dilation
|
||||
if next_dilation != current_dilation:
|
||||
current_dilation = next_dilation
|
||||
|
||||
# create the block
|
||||
block = self._make_block(block_args, total_block_idx, total_block_count)
|
||||
blocks.append(block)
|
||||
|
||||
# stash feature module name and channel info for model feature extraction
|
||||
if extract_features:
|
||||
feature_module = block.feature_module(extract_features)
|
||||
if feature_module:
|
||||
feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module
|
||||
feature_channels = block.feature_channels(extract_features)
|
||||
self.features[feature_idx] = dict(
|
||||
name=feature_module,
|
||||
num_chs=feature_channels
|
||||
)
|
||||
feature_idx += 1
|
||||
|
||||
total_block_idx += 1 # incr global block idx (across all stacks)
|
||||
stages.append(nn.Sequential(*blocks))
|
||||
return stages
|
||||
|
||||
|
||||
def efficientnet_init_goog(m, n=''):
|
||||
# weight init as per Tensorflow Official impl
|
||||
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
||||
if isinstance(m, CondConv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
init_weight_fn = get_condconv_initializer(
|
||||
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
||||
init_weight_fn(m.weight)
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
fan_out = m.weight.size(0) # fan-out
|
||||
fan_in = 0
|
||||
if 'routing_fn' in n:
|
||||
fan_in = m.weight.size(1)
|
||||
init_range = 1.0 / math.sqrt(fan_in + fan_out)
|
||||
m.weight.data.uniform_(-init_range, init_range)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
def efficientnet_init_default(m, n=''):
|
||||
if isinstance(m, CondConv2d):
|
||||
init_fn = get_condconv_initializer(partial(
|
||||
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
|
||||
init_fn(m.weight)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
|
||||
|
||||
|
@ -0,0 +1,31 @@
|
||||
from collections import defaultdict, OrderedDict
|
||||
from functools import partial
|
||||
|
||||
|
||||
class FeatureHooks:
|
||||
|
||||
def __init__(self, hooks, named_modules):
|
||||
# setup feature hooks
|
||||
modules = {k: v for k, v in named_modules}
|
||||
for h in hooks:
|
||||
hook_name = h['name']
|
||||
m = modules[hook_name]
|
||||
hook_fn = partial(self._collect_output_hook, hook_name)
|
||||
if h['type'] == 'forward_pre':
|
||||
m.register_forward_pre_hook(hook_fn)
|
||||
elif h['type'] == 'forward':
|
||||
m.register_forward_hook(hook_fn)
|
||||
else:
|
||||
assert False, "Unsupported hook type"
|
||||
self._feature_outputs = defaultdict(OrderedDict)
|
||||
|
||||
def _collect_output_hook(self, name, *args):
|
||||
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
||||
if isinstance(x, tuple):
|
||||
x = x[0] # unwrap input tuple
|
||||
self._feature_outputs[x.device][name] = x
|
||||
|
||||
def get_output(self, device):
|
||||
output = tuple(self._feature_outputs[device].values())[::-1]
|
||||
self._feature_outputs[device] = OrderedDict() # clear after reading
|
||||
return output
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,869 @@
|
||||
""" HRNet
|
||||
|
||||
Copied from https://github.com/HRNet/HRNet-Image-Classification
|
||||
|
||||
Original header:
|
||||
Copyright (c) Microsoft
|
||||
Licensed under the MIT License.
|
||||
Written by Bin Xiao (Bin.Xiao@microsoft.com)
|
||||
Modified by Ke Sun (sunk@mail.ustc.edu.cn)
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import logging
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch._utils
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
BN_MOMENTUM = 0.1
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv1', 'classifier': 'fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'hrnet_w18_small': _cfg(url=''),
|
||||
'hrnet_w18_small_v2': _cfg(url=''),
|
||||
'hrnet_w18': _cfg(url=''),
|
||||
'hrnet_w30': _cfg(url=''),
|
||||
'hrnet_w32': _cfg(url=''),
|
||||
'hrnet_w40': _cfg(url=''),
|
||||
'hrnet_w44': _cfg(url=''),
|
||||
'hrnet_w48': _cfg(url=''),
|
||||
}
|
||||
|
||||
cfg_cls_hrnet_w18_small = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(1,),
|
||||
NUM_CHANNELS=(32,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(2, 2),
|
||||
NUM_CHANNELS=(16, 32),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(2, 2, 2),
|
||||
NUM_CHANNELS=(16, 32, 64),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(2, 2, 2, 2),
|
||||
NUM_CHANNELS=(16, 32, 64, 128),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
cfg_cls_hrnet_w18_small_v2 = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(2,),
|
||||
NUM_CHANNELS=(64,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(2, 2),
|
||||
NUM_CHANNELS=(18, 36),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=3,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(2, 2, 2),
|
||||
NUM_CHANNELS=(18, 36, 72),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=2,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(2, 2, 2, 2),
|
||||
NUM_CHANNELS=(18, 36, 72, 144),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
cfg_cls_hrnet_w18 = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(4,),
|
||||
NUM_CHANNELS=(64,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4),
|
||||
NUM_CHANNELS=(18, 36),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=4,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4),
|
||||
NUM_CHANNELS=(18, 36, 72),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=3,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4, 4),
|
||||
NUM_CHANNELS=(18, 36, 72, 144),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
cfg_cls_hrnet_w30 = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(4,),
|
||||
NUM_CHANNELS=(64,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4),
|
||||
NUM_CHANNELS=(30, 60),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=4,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4),
|
||||
NUM_CHANNELS=(30, 60, 120),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=3,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4, 4),
|
||||
NUM_CHANNELS=(30, 60, 120, 240),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
cfg_cls_hrnet_w32 = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(4,),
|
||||
NUM_CHANNELS=(64,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4),
|
||||
NUM_CHANNELS=(32, 64),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=4,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4),
|
||||
NUM_CHANNELS=(32, 64, 128),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=3,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4, 4),
|
||||
NUM_CHANNELS=(32, 64, 128, 256),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
cfg_cls_hrnet_w40 = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(4,),
|
||||
NUM_CHANNELS=(64,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4),
|
||||
NUM_CHANNELS=(40, 80),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=4,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4),
|
||||
NUM_CHANNELS=(40, 80, 160),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=3,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4, 4),
|
||||
NUM_CHANNELS=(40, 80, 160, 320),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
cfg_cls_hrnet_w44 = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(4,),
|
||||
NUM_CHANNELS=(64,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4),
|
||||
NUM_CHANNELS=(44, 88),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=4,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4),
|
||||
NUM_CHANNELS=(44, 88, 176),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=3,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4, 4),
|
||||
NUM_CHANNELS=(44, 88, 176, 352),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
cfg_cls_hrnet_w48 = dict(
|
||||
STAGE1=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=1,
|
||||
BLOCK='BOTTLENECK',
|
||||
NUM_BLOCKS=(4,),
|
||||
NUM_CHANNELS=(64,),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
STAGE2=dict(
|
||||
NUM_MODULES=1,
|
||||
NUM_BRANCHES=2,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4),
|
||||
NUM_CHANNELS=(48, 96),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE3=dict(
|
||||
NUM_MODULES=4,
|
||||
NUM_BRANCHES=3,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4),
|
||||
NUM_CHANNELS=(48, 96, 192),
|
||||
FUSE_METHOD='SUM'
|
||||
),
|
||||
STAGE4=dict(
|
||||
NUM_MODULES=3,
|
||||
NUM_BRANCHES=4,
|
||||
BLOCK='BASIC',
|
||||
NUM_BLOCKS=(4, 4, 4, 4),
|
||||
NUM_CHANNELS=(48, 96, 192, 384),
|
||||
FUSE_METHOD='SUM',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
||||
self.conv3 = nn.Conv2d(
|
||||
planes, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(
|
||||
planes * self.expansion, momentum=BN_MOMENTUM)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class HighResolutionModule(nn.Module):
|
||||
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
|
||||
num_channels, fuse_method, multi_scale_output=True):
|
||||
super(HighResolutionModule, self).__init__()
|
||||
self._check_branches(
|
||||
num_branches, blocks, num_blocks, num_inchannels, num_channels)
|
||||
|
||||
self.num_inchannels = num_inchannels
|
||||
self.fuse_method = fuse_method
|
||||
self.num_branches = num_branches
|
||||
|
||||
self.multi_scale_output = multi_scale_output
|
||||
|
||||
self.branches = self._make_branches(
|
||||
num_branches, blocks, num_blocks, num_channels)
|
||||
self.fuse_layers = self._make_fuse_layers()
|
||||
self.relu = nn.ReLU(False)
|
||||
|
||||
def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
|
||||
if num_branches != len(num_blocks):
|
||||
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
|
||||
num_branches, len(num_blocks))
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(num_channels):
|
||||
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
|
||||
num_branches, len(num_channels))
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(num_inchannels):
|
||||
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
|
||||
num_branches, len(num_inchannels))
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
|
||||
stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or \
|
||||
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample))
|
||||
self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
|
||||
for i in range(1, num_blocks[branch_index]):
|
||||
layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
||||
branches = []
|
||||
|
||||
for i in range(num_branches):
|
||||
branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
|
||||
|
||||
return nn.ModuleList(branches)
|
||||
|
||||
def _make_fuse_layers(self):
|
||||
if self.num_branches == 1:
|
||||
return None
|
||||
|
||||
num_branches = self.num_branches
|
||||
num_inchannels = self.num_inchannels
|
||||
fuse_layers = []
|
||||
for i in range(num_branches if self.multi_scale_output else 1):
|
||||
fuse_layer = []
|
||||
for j in range(num_branches):
|
||||
if j > i:
|
||||
fuse_layer.append(nn.Sequential(
|
||||
nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM),
|
||||
nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
|
||||
elif j == i:
|
||||
fuse_layer.append(None)
|
||||
else:
|
||||
conv3x3s = []
|
||||
for k in range(i - j):
|
||||
if k == i - j - 1:
|
||||
num_outchannels_conv3x3 = num_inchannels[i]
|
||||
conv3x3s.append(nn.Sequential(
|
||||
nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM)))
|
||||
else:
|
||||
num_outchannels_conv3x3 = num_inchannels[j]
|
||||
conv3x3s.append(nn.Sequential(
|
||||
nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM),
|
||||
nn.ReLU(False)))
|
||||
fuse_layer.append(nn.Sequential(*conv3x3s))
|
||||
fuse_layers.append(nn.ModuleList(fuse_layer))
|
||||
|
||||
return nn.ModuleList(fuse_layers)
|
||||
|
||||
def get_num_inchannels(self):
|
||||
return self.num_inchannels
|
||||
|
||||
def forward(self, x):
|
||||
if self.num_branches == 1:
|
||||
return [self.branches[0](x[0])]
|
||||
|
||||
for i in range(self.num_branches):
|
||||
x[i] = self.branches[i](x[i])
|
||||
|
||||
x_fuse = []
|
||||
for i in range(len(self.fuse_layers)):
|
||||
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
||||
for j in range(1, self.num_branches):
|
||||
if i == j:
|
||||
y = y + x[j]
|
||||
else:
|
||||
y = y + self.fuse_layers[i][j](x[j])
|
||||
x_fuse.append(self.relu(y))
|
||||
|
||||
return x_fuse
|
||||
|
||||
|
||||
blocks_dict = {
|
||||
'BASIC': BasicBlock,
|
||||
'BOTTLENECK': Bottleneck
|
||||
}
|
||||
|
||||
|
||||
class HighResolutionNet(nn.Module):
|
||||
|
||||
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg'):
|
||||
super(HighResolutionNet, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
||||
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.stage1_cfg = cfg['STAGE1']
|
||||
num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
|
||||
block = blocks_dict[self.stage1_cfg['BLOCK']]
|
||||
num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
|
||||
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
|
||||
stage1_out_channel = block.expansion * num_channels
|
||||
|
||||
self.stage2_cfg = cfg['STAGE2']
|
||||
num_channels = self.stage2_cfg['NUM_CHANNELS']
|
||||
block = blocks_dict[self.stage2_cfg['BLOCK']]
|
||||
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
|
||||
self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
|
||||
self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
|
||||
|
||||
self.stage3_cfg = cfg['STAGE3']
|
||||
num_channels = self.stage3_cfg['NUM_CHANNELS']
|
||||
block = blocks_dict[self.stage3_cfg['BLOCK']]
|
||||
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
|
||||
self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
|
||||
self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
|
||||
|
||||
self.stage4_cfg = cfg['STAGE4']
|
||||
num_channels = self.stage4_cfg['NUM_CHANNELS']
|
||||
block = blocks_dict[self.stage4_cfg['BLOCK']]
|
||||
num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
|
||||
self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
|
||||
self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True)
|
||||
|
||||
# Classification Head
|
||||
self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels)
|
||||
|
||||
self.classifier = nn.Linear(2048, num_classes)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _make_head(self, pre_stage_channels):
|
||||
head_block = Bottleneck
|
||||
head_channels = [32, 64, 128, 256]
|
||||
|
||||
# Increasing the #channels on each resolution
|
||||
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
|
||||
incre_modules = []
|
||||
for i, channels in enumerate(pre_stage_channels):
|
||||
incre_modules.append(
|
||||
self._make_layer(head_block, channels, head_channels[i], 1, stride=1))
|
||||
incre_modules = nn.ModuleList(incre_modules)
|
||||
|
||||
# downsampling modules
|
||||
downsamp_modules = []
|
||||
for i in range(len(pre_stage_channels) - 1):
|
||||
in_channels = head_channels[i] * head_block.expansion
|
||||
out_channels = head_channels[i + 1] * head_block.expansion
|
||||
downsamp_module = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
downsamp_modules.append(downsamp_module)
|
||||
downsamp_modules = nn.ModuleList(downsamp_modules)
|
||||
|
||||
final_layer = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=head_channels[3] * head_block.expansion,
|
||||
out_channels=2048, kernel_size=1, stride=1, padding=0
|
||||
),
|
||||
nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
return incre_modules, downsamp_modules, final_layer
|
||||
|
||||
def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
|
||||
num_branches_cur = len(num_channels_cur_layer)
|
||||
num_branches_pre = len(num_channels_pre_layer)
|
||||
|
||||
transition_layers = []
|
||||
for i in range(num_branches_cur):
|
||||
if i < num_branches_pre:
|
||||
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
||||
transition_layers.append(nn.Sequential(
|
||||
nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(num_channels_cur_layer[i], momentum=BN_MOMENTUM),
|
||||
nn.ReLU(inplace=True)))
|
||||
else:
|
||||
transition_layers.append(None)
|
||||
else:
|
||||
conv3x3s = []
|
||||
for j in range(i + 1 - num_branches_pre):
|
||||
inchannels = num_channels_pre_layer[-1]
|
||||
outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
|
||||
conv3x3s.append(nn.Sequential(
|
||||
nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
|
||||
nn.ReLU(inplace=True)))
|
||||
transition_layers.append(nn.Sequential(*conv3x3s))
|
||||
|
||||
return nn.ModuleList(transition_layers)
|
||||
|
||||
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(inplanes, planes, stride, downsample))
|
||||
inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
|
||||
num_modules = layer_config['NUM_MODULES']
|
||||
num_branches = layer_config['NUM_BRANCHES']
|
||||
num_blocks = layer_config['NUM_BLOCKS']
|
||||
num_channels = layer_config['NUM_CHANNELS']
|
||||
block = blocks_dict[layer_config['BLOCK']]
|
||||
fuse_method = layer_config['FUSE_METHOD']
|
||||
|
||||
modules = []
|
||||
for i in range(num_modules):
|
||||
# multi_scale_output is only used last module
|
||||
if not multi_scale_output and i == num_modules - 1:
|
||||
reset_multi_scale_output = False
|
||||
else:
|
||||
reset_multi_scale_output = True
|
||||
|
||||
modules.append(HighResolutionModule(
|
||||
num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output)
|
||||
)
|
||||
num_inchannels = modules[-1].get_num_inchannels()
|
||||
|
||||
return nn.Sequential(*modules), num_inchannels
|
||||
|
||||
def init_weights(self, pretrained='', ):
|
||||
logger.info('=> init weights from normal distribution')
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage2_cfg['NUM_BRANCHES']):
|
||||
if self.transition1[i] is not None:
|
||||
x_list.append(self.transition1[i](x))
|
||||
else:
|
||||
x_list.append(x)
|
||||
y_list = self.stage2(x_list)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage3_cfg['NUM_BRANCHES']):
|
||||
if self.transition2[i] is not None:
|
||||
x_list.append(self.transition2[i](y_list[-1]))
|
||||
else:
|
||||
x_list.append(y_list[i])
|
||||
y_list = self.stage3(x_list)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage4_cfg['NUM_BRANCHES']):
|
||||
if self.transition3[i] is not None:
|
||||
x_list.append(self.transition3[i](y_list[-1]))
|
||||
else:
|
||||
x_list.append(y_list[i])
|
||||
y_list = self.stage4(x_list)
|
||||
|
||||
# Classification Head
|
||||
y = self.incre_modules[0](y_list[0])
|
||||
for i in range(len(self.downsamp_modules)):
|
||||
y = self.incre_modules[i + 1](y_list[i + 1]) + self.downsamp_modules[i](y)
|
||||
|
||||
y = self.final_layer(y)
|
||||
|
||||
if torch._C._get_tracing_state():
|
||||
y = y.flatten(start_dim=2).mean(dim=2)
|
||||
else:
|
||||
y = F.avg_pool2d(y, kernel_size=y.size()[2:]).view(y.size(0), -1)
|
||||
|
||||
y = self.classifier(y)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def hrnet_w18_small(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w18_small']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w18_small, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hrnet_w18_small_v2(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w18_small_v2']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w18_small_v2, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
@register_model
|
||||
def hrnet_w18(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w18']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w18, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hrnet_w30(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w30']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w30, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
@register_model
|
||||
def hrnet_w32(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w32']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w32, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
@register_model
|
||||
def hrnet_w40(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w40']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w40, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hrnet_w44(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w44']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w44, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def hrnet_w48(pretrained=True, **kwargs):
|
||||
default_cfg = default_cfgs['hrnet_w48']
|
||||
model = HighResolutionNet(cfg_cls_hrnet_w48, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
@ -0,0 +1,469 @@
|
||||
|
||||
""" MobileNet V3
|
||||
|
||||
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
|
||||
|
||||
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .efficientnet_builder import *
|
||||
from .activations import HardSwish, hard_sigmoid
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from .conv2d_layers import select_conv2d
|
||||
from .feature_hooks import FeatureHooks
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
__all__ = ['MobileNetV3']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'mobilenetv3_large_075': _cfg(url=''),
|
||||
'mobilenetv3_large_100': _cfg(url=''),
|
||||
'mobilenetv3_small_075': _cfg(url=''),
|
||||
'mobilenetv3_small_100': _cfg(url=''),
|
||||
'mobilenetv3_rw': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
|
||||
interpolation='bicubic'),
|
||||
'tf_mobilenetv3_large_075': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
'tf_mobilenetv3_large_100': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
'tf_mobilenetv3_large_minimal_100': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
'tf_mobilenetv3_small_075': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
'tf_mobilenetv3_small_100': _cfg(
|
||||
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
'tf_mobilenetv3_small_minimal_100': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
|
||||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||
}
|
||||
|
||||
_DEBUG = False
|
||||
|
||||
|
||||
class MobileNetV3(nn.Module):
|
||||
""" MobiletNet-V3
|
||||
|
||||
Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific
|
||||
'efficient head', where global pooling is done before the head convolution without a final batch-norm
|
||||
layer before the classifier.
|
||||
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
|
||||
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||
global_pool='avg', weight_init='goog'):
|
||||
super(MobileNetV3, self).__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.num_features = num_features
|
||||
self.drop_rate = drop_rate
|
||||
self._in_chs = in_chans
|
||||
|
||||
# Stem
|
||||
stem_size = round_channels(stem_size, channel_multiplier)
|
||||
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self._in_chs = stem_size
|
||||
|
||||
# Middle stages (IR/ER/DS Blocks)
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features
|
||||
self._in_chs = builder.in_chs
|
||||
|
||||
# Head + Pooling
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
|
||||
# Classifier
|
||||
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if weight_init == 'goog':
|
||||
efficientnet_init_goog(m)
|
||||
else:
|
||||
efficientnet_init_default(m)
|
||||
|
||||
def as_sequential(self):
|
||||
layers = [self.conv_stem, self.bn1, self.act1]
|
||||
layers.extend(self.blocks)
|
||||
layers.extend([self.global_pool, self.conv_head, self.act2])
|
||||
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.classifier
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.num_classes = num_classes
|
||||
del self.classifier
|
||||
if num_classes:
|
||||
self.classifier = nn.Linear(
|
||||
self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
else:
|
||||
self.classifier = None
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
x = self.blocks(x)
|
||||
x = self.global_pool(x)
|
||||
x = self.conv_head(x)
|
||||
x = self.act2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = x.flatten(1)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
class MobileNetV3Features(nn.Module):
|
||||
""" MobileNetV3 Feature Extractor
|
||||
|
||||
A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation
|
||||
and object detection models.
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
||||
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
|
||||
act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
|
||||
super(MobileNetV3Features, self).__init__()
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
|
||||
# TODO only create stages needed, currently all stages are created regardless of out_indices
|
||||
num_stages = max(out_indices) + 1
|
||||
|
||||
self.out_indices = out_indices
|
||||
self.drop_rate = drop_rate
|
||||
self._in_chs = in_chans
|
||||
|
||||
# Stem
|
||||
stem_size = round_channels(stem_size, channel_multiplier)
|
||||
self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self._in_chs = stem_size
|
||||
|
||||
# Middle stages (IR/ER/DS Blocks)
|
||||
builder = EfficientNetBuilder(
|
||||
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._in_chs = builder.in_chs
|
||||
|
||||
for m in self.modules():
|
||||
if weight_init == 'goog':
|
||||
efficientnet_init_goog(m)
|
||||
else:
|
||||
efficientnet_init_default(m)
|
||||
|
||||
if _DEBUG:
|
||||
for k, v in self.feature_info.items():
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
||||
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
|
||||
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def feature_channels(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
Returns feature channel count for each output index if idx == None. If idx is an integer, will
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.feature_info[idx]['num_chs']
|
||||
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
self.blocks(x)
|
||||
return self.feature_hooks.get_output(x.device)
|
||||
|
||||
|
||||
def _create_model(model_kwargs, default_cfg, pretrained=False):
|
||||
if model_kwargs.pop('features_only', False):
|
||||
load_strict = False
|
||||
model_kwargs.pop('num_classes', 0)
|
||||
model_kwargs.pop('num_features', 0)
|
||||
model_kwargs.pop('head_conv', None)
|
||||
model_class = MobileNetV3Features
|
||||
else:
|
||||
load_strict = True
|
||||
model_class = MobileNetV3
|
||||
|
||||
model = model_class(**model_kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
default_cfg,
|
||||
num_classes=model_kwargs.get('num_classes', 0),
|
||||
in_chans=model_kwargs.get('in_chans', 3),
|
||||
strict=load_strict)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MobileNet-V3 model.
|
||||
|
||||
Ref impl: ?
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
head_bias=False,
|
||||
channel_multiplier=channel_multiplier,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
act_layer=HardSwish,
|
||||
se_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
"""Creates a MobileNet-V3 model.
|
||||
|
||||
Ref impl: ?
|
||||
Paper: https://arxiv.org/abs/1905.02244
|
||||
|
||||
Args:
|
||||
channel_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
if 'small' in variant:
|
||||
num_features = 1024
|
||||
if 'minimal' in variant:
|
||||
act_layer = nn.ReLU
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s2_e1_c16'],
|
||||
# stage 1, 56x56 in
|
||||
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
|
||||
# stage 2, 28x28 in
|
||||
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
|
||||
# stage 3, 14x14 in
|
||||
['ir_r2_k3_s1_e3_c48'],
|
||||
# stage 4, 14x14in
|
||||
['ir_r3_k3_s2_e6_c96'],
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c576'],
|
||||
]
|
||||
else:
|
||||
act_layer = HardSwish
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
|
||||
# stage 1, 56x56 in
|
||||
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
|
||||
# stage 2, 28x28 in
|
||||
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
|
||||
# stage 3, 14x14 in
|
||||
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c576'], # hard-swish
|
||||
]
|
||||
else:
|
||||
num_features = 1280
|
||||
if 'minimal' in variant:
|
||||
act_layer = nn.ReLU
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16'],
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k3_s2_e3_c40'],
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112'],
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k3_s2_e6_c160'],
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'],
|
||||
]
|
||||
else:
|
||||
act_layer = HardSwish
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# stage 4, 14x14in
|
||||
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
|
||||
# stage 5, 14x14in
|
||||
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
stem_size=16,
|
||||
channel_multiplier=channel_multiplier,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
act_layer=act_layer,
|
||||
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
|
||||
**kwargs,
|
||||
)
|
||||
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_large_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_large_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_small_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_small_100(pretrained=False, **kwargs):
|
||||
print(kwargs)
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_rw(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
if pretrained:
|
||||
# pretrained model trained with non-default BN epsilon
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
Loading…
Reference in new issue