@@ -276,12 +276,13 @@ class _BlockBuilder:
     """

     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
+                 drop_connect_rate=0., act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  folded_bn=False, padding_same=False, verbose=False):
         self.channel_multiplier = channel_multiplier
         self.channel_divisor = channel_divisor
         self.channel_min = channel_min
+        self.drop_connect_rate = drop_connect_rate
         self.act_fn = act_fn
         self.se_gate_fn = se_gate_fn
         self.se_reduce_mid = se_reduce_mid
@@ -310,10 +311,12 @@ class _BlockBuilder:
             print('args:', ba)
         # could replace this if with lambdas or functools binding if variety increases
         if bt == 'ir':
+            ba['drop_connect_rate'] = self.drop_connect_rate
             ba['se_gate_fn'] = self.se_gate_fn
             ba['se_reduce_mid'] = self.se_reduce_mid
             block = InvertedResidual(**ba)
         elif bt == 'ds' or bt == 'dsa':
+            ba['drop_connect_rate'] = self.drop_connect_rate
             block = DepthwiseSeparableConv(**ba)
         elif bt == 'ca':
             block = CascadeConv(**ba)
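Note that the builder threads drop_connect_rate only into the residual-capable block types ('ir' and 'ds'/'dsa'); CascadeConv ('ca') never sees it. The in-code comment suggests the if/elif dispatch could become a table of constructors if the block variety grows. A minimal sketch of that idea, not part of this patch; the _BLOCK_FACTORY and _make_block names are purely illustrative:

    # Hypothetical refactor sketch, not in this patch.
    _BLOCK_FACTORY = {
        'ir': InvertedResidual,
        'ds': DepthwiseSeparableConv,
        'dsa': DepthwiseSeparableConv,
        'ca': CascadeConv,
    }

    def _make_block(self, bt, ba):
        if bt in ('ir', 'ds', 'dsa'):
            ba['drop_connect_rate'] = self.drop_connect_rate  # residual blocks only
        if bt == 'ir':
            ba['se_gate_fn'] = self.se_gate_fn
            ba['se_reduce_mid'] = self.se_reduce_mid
        return _BLOCK_FACTORY[bt](**ba)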
@@ -402,6 +405,19 @@ def hard_sigmoid(x):
     return F.relu6(x + 3.) / 6.


+def drop_connect(inputs, training=False, drop_connect_rate=0.):
+    """Apply drop connect."""
+    if not training:
+        return inputs
+
+    keep_prob = 1 - drop_connect_rate
+    random_tensor = keep_prob + torch.rand(
+        (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
+    random_tensor.floor_()  # binarize
+    output = inputs.div(keep_prob) * random_tensor
+    return output
+
+
 class ChannelShuffle(nn.Module):
     # FIXME haven't used yet
     def __init__(self, groups):
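The new drop_connect helper is per-sample stochastic depth: during training it draws one Bernoulli(keep_prob) mask per sample, zeroes dropped samples wholesale, and divides survivors by keep_prob so the expected activation is unchanged; outside training it is the identity. A quick sanity check, illustrative only and assuming drop_connect is importable from this module:

    import torch

    x = torch.ones(1000, 8, 4, 4)

    # inference path: identity
    assert torch.equal(drop_connect(x, training=False, drop_connect_rate=0.2), x)

    # training path: each sample is either all zeros or all 1/keep_prob (= 1.25)
    y = drop_connect(x, training=True, drop_connect_rate=0.2)
    rows = y.view(1000, -1)
    assert all(r.eq(0).all() or r.eq(1.25).all() for r in rows)
    print(y.mean().item())  # close to 1.0, expectation preserved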
@@ -474,13 +490,14 @@ class DepthwiseSeparableConv(nn.Module):
                  stride=1, act_fn=F.relu, noskip=False, pw_act=False,
                  se_ratio=0., se_gate_fn=torch.sigmoid,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
+                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         assert stride in [1, 2]
         self.has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.act_fn = act_fn
+        self.drop_connect_rate = drop_connect_rate
         dw_padding = _padding_arg(kernel_size // 2, padding_same)
         pw_padding = _padding_arg(0, padding_same)
@@ -515,7 +532,9 @@ class DepthwiseSeparableConv(nn.Module):
             x = self.act_fn(x)

         if self.has_residual:
-            x += residual  # FIXME add drop-connect
+            if self.drop_connect_rate > 0.:
+                x = drop_connect(x, self.training, self.drop_connect_rate)
+            x += residual
         return x

@@ -557,12 +576,13 @@ class InvertedResidual(nn.Module):
                  se_ratio=0., se_reduce_mid=False, se_gate_fn=torch.sigmoid,
                  shuffle_type=None, pw_group=1,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
+                 folded_bn=False, padding_same=False, drop_connect_rate=0.):
         super(InvertedResidual, self).__init__()
         mid_chs = int(in_chs * exp_ratio)
         self.has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.act_fn = act_fn
+        self.drop_connect_rate = drop_connect_rate
         dw_padding = _padding_arg(kernel_size // 2, padding_same)
         pw_padding = _padding_arg(0, padding_same)
@@ -619,7 +639,9 @@ class InvertedResidual(nn.Module):
         x = self.bn3(x)

         if self.has_residual:
-            x += residual  # FIXME add drop-connect
+            if self.drop_connect_rate > 0.:
+                x = drop_connect(x, self.training, self.drop_connect_rate)
+            x += residual
         return x

     # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants
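Both forward passes resolve the old "FIXME add drop-connect" comment the same way: drop_connect is applied only when a residual shortcut exists, and before the skip addition, so a dropped sample reduces the block to its identity path rather than to zero. Blocks with stride 2 or a channel change have no shortcut and are never dropped. A small illustrative stand-in for this pattern, not code from the patch:

    import torch
    import torch.nn as nn

    class ResidualDemo(nn.Module):
        """Toy block mirroring the residual/drop-connect ordering above."""
        def __init__(self, body, drop_connect_rate=0.2):
            super().__init__()
            self.body = body
            self.drop_connect_rate = drop_connect_rate

        def forward(self, x):
            residual = x
            out = self.body(x)
            if self.drop_connect_rate > 0.:
                out = drop_connect(out, self.training, self.drop_connect_rate)
            return out + residual

    demo = ResidualDemo(nn.Conv2d(8, 8, 1)).eval()
    x = torch.randn(2, 8, 4, 4)
    assert torch.equal(demo(x), demo(x))  # eval mode: drop_connect is identity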
@@ -643,12 +665,14 @@ class GenMobileNet(nn.Module):
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
+                 drop_rate=0., drop_connect_rate=0., act_fn=F.relu,
+                 se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  global_pool='avg', head_conv='default', weight_init='goog',
-                 folded_bn=False, padding_same=False):
+                 folded_bn=False, padding_same=False,):
         super(GenMobileNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.drop_connect_rate = drop_connect_rate
         self.act_fn = act_fn
         self.num_features = num_features

@@ -661,7 +685,7 @@ class GenMobileNet(nn.Module):

         builder = _BlockBuilder(
             channel_multiplier, channel_divisor, channel_min,
-            act_fn, se_gate_fn, se_reduce_mid,
+            drop_connect_rate, act_fn, se_gate_fn, se_reduce_mid,
             bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
         in_chs = builder.in_chs
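With this wiring, a single drop_connect_rate set at model construction reaches every residual block through _BlockBuilder. A usage sketch, assuming kwargs flow from the model entry points below through _gen_efficientnet into GenMobileNet, and using 0.2 only as an example value:

    import torch

    model = efficientnet_b0(num_classes=1000, drop_connect_rate=0.2)
    model.train()  # drop_connect is active only while training
    logits = model(torch.randn(8, 3, 224, 224))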
@@ -1090,7 +1114,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):


 def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
-    """Creates a MobileNet-V3 model.
+    """Creates an EfficientNet model.

     Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
     Paper: https://arxiv.org/abs/1905.11946
@@ -1347,7 +1371,7 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b0']
-    # NOTE dropout should be 0.2 for train
+    # NOTE for train, drop_rate should be 0.2
     model = _gen_efficientnet(
         channel_multiplier=1.0, depth_multiplier=1.0,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1360,7 +1384,7 @@ def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b1']
-    # NOTE dropout should be 0.2 for train
+    # NOTE for train, drop_rate should be 0.2
     model = _gen_efficientnet(
         channel_multiplier=1.0, depth_multiplier=1.1,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1373,7 +1397,7 @@ def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b2']
-    # NOTE dropout should be 0.3 for train
+    # NOTE for train, drop_rate should be 0.3
     model = _gen_efficientnet(
         channel_multiplier=1.1, depth_multiplier=1.2,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1386,7 +1410,7 @@ def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b3']
-    # NOTE dropout should be 0.3 for train
+    # NOTE for train, drop_rate should be 0.3
     model = _gen_efficientnet(
         channel_multiplier=1.2, depth_multiplier=1.4,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -1399,7 +1423,7 @@ def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
 def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
     """ EfficientNet """
     default_cfg = default_cfgs['efficientnet_b4']
-    # NOTE dropout should be 0.4 for train
+    # NOTE for train, drop_rate should be 0.4
     model = _gen_efficientnet(
         channel_multiplier=1.4, depth_multiplier=1.8,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
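For reference, the b0-b4 constructors above pin down the compound-scaling settings this patch touches; collected here with values copied from the hunks (the dict name is illustrative, not part of the code):

    # (channel_multiplier, depth_multiplier, suggested train drop_rate)
    EFFICIENTNET_SCALING = {
        'efficientnet_b0': (1.0, 1.0, 0.2),
        'efficientnet_b1': (1.0, 1.1, 0.2),
        'efficientnet_b2': (1.1, 1.2, 0.3),
        'efficientnet_b3': (1.2, 1.4, 0.3),
        'efficientnet_b4': (1.4, 1.8, 0.4),
    }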