Unify drop connect vs drop path under 'drop path' name, switch all EfficientNet/MobilenetV3 refs to 'drop_path'. Update factory to handle new drop args.

pull/88/head
Ross Wightman 4 years ago
parent f1d5f8a6c4
commit 43225d110c

@ -253,7 +253,7 @@ class EfficientNet(nn.Module):
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(EfficientNet, self).__init__()
norm_kwargs = norm_kwargs or {}
@ -273,7 +273,7 @@ class EfficientNet(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features
self._in_chs = builder.in_chs
@ -333,7 +333,7 @@ class EfficientNetFeatures(nn.Module):
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(EfficientNetFeatures, self).__init__()
norm_kwargs = norm_kwargs or {}
@ -355,7 +355,7 @@ class EfficientNetFeatures(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG)
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features # builder provides info about feature channels for each block
self._in_chs = builder.in_chs
@ -875,7 +875,7 @@ def spnasnet_100(pretrained=False, **kwargs):
@register_model
def efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0 """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@ -884,7 +884,7 @@ def efficientnet_b0(pretrained=False, **kwargs):
@register_model
def efficientnet_b1(pretrained=False, **kwargs):
""" EfficientNet-B1 """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@ -893,7 +893,7 @@ def efficientnet_b1(pretrained=False, **kwargs):
@register_model
def efficientnet_b2(pretrained=False, **kwargs):
""" EfficientNet-B2 """
# NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@ -902,7 +902,7 @@ def efficientnet_b2(pretrained=False, **kwargs):
@register_model
def efficientnet_b2a(pretrained=False, **kwargs):
""" EfficientNet-B2 @ 288x288 w/ 1.0 test crop"""
# NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b2a', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@ -911,7 +911,7 @@ def efficientnet_b2a(pretrained=False, **kwargs):
@register_model
def efficientnet_b3(pretrained=False, **kwargs):
""" EfficientNet-B3 """
# NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@ -920,7 +920,7 @@ def efficientnet_b3(pretrained=False, **kwargs):
@register_model
def efficientnet_b3a(pretrained=False, **kwargs):
""" EfficientNet-B3 @ 320x320 w/ 1.0 test crop-pct """
# NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b3a', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@ -929,7 +929,7 @@ def efficientnet_b3a(pretrained=False, **kwargs):
@register_model
def efficientnet_b4(pretrained=False, **kwargs):
""" EfficientNet-B4 """
# NOTE for train, drop_rate should be 0.4, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
return model
@ -938,7 +938,7 @@ def efficientnet_b4(pretrained=False, **kwargs):
@register_model
def efficientnet_b5(pretrained=False, **kwargs):
""" EfficientNet-B5 """
# NOTE for train, drop_rate should be 0.4, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model
@ -947,7 +947,7 @@ def efficientnet_b5(pretrained=False, **kwargs):
@register_model
def efficientnet_b6(pretrained=False, **kwargs):
""" EfficientNet-B6 """
# NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
return model
@ -956,7 +956,7 @@ def efficientnet_b6(pretrained=False, **kwargs):
@register_model
def efficientnet_b7(pretrained=False, **kwargs):
""" EfficientNet-B7 """
# NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
return model
@ -965,7 +965,7 @@ def efficientnet_b7(pretrained=False, **kwargs):
@register_model
def efficientnet_b8(pretrained=False, **kwargs):
""" EfficientNet-B8 """
# NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
return model
@ -974,7 +974,7 @@ def efficientnet_b8(pretrained=False, **kwargs):
@register_model
def efficientnet_l2(pretrained=False, **kwargs):
""" EfficientNet-L2."""
# NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
return model
@ -1007,7 +1007,7 @@ def efficientnet_el(pretrained=False, **kwargs):
@register_model
def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 8 Experts """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet_condconv(
'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@ -1016,7 +1016,7 @@ def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
@register_model
def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 8 Experts """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet_condconv(
'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
pretrained=pretrained, **kwargs)
@ -1025,7 +1025,7 @@ def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
@register_model
def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet_condconv(
'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
pretrained=pretrained, **kwargs)
@ -1355,7 +1355,7 @@ def tf_efficientnet_el(pretrained=False, **kwargs):
@register_model
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_condconv(
@ -1366,7 +1366,7 @@ def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
@register_model
def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_condconv(
@ -1377,7 +1377,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
@register_model
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
""" EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
# NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_efficientnet_condconv(

@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers.activations import sigmoid
from .layers import create_conv2d
from .layers import create_conv2d, drop_path
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@ -69,19 +69,6 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
return make_divisible(channels, divisor, channel_min)
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
"""Apply drop connect."""
if not training:
return inputs
keep_prob = 1 - drop_connect_rate
random_tensor = keep_prob + torch.rand(
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
random_tensor.floor_() # binarize
output = inputs.div(keep_prob) * random_tensor
return output
class ChannelShuffle(nn.Module):
# FIXME haven't used yet
def __init__(self, groups):
@ -154,13 +141,13 @@ class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_connect_rate = drop_connect_rate
self.drop_path_rate = drop_path_rate
self.conv_dw = create_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
@ -200,8 +187,8 @@ class DepthwiseSeparableConv(nn.Module):
x = self.act2(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
@ -213,14 +200,14 @@ class InvertedResidual(nn.Module):
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_connect_rate=0.):
conv_kwargs=None, drop_path_rate=0.):
super(InvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
self.drop_path_rate = drop_path_rate
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
@ -278,8 +265,8 @@ class InvertedResidual(nn.Module):
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
@ -292,7 +279,7 @@ class CondConvResidual(InvertedResidual):
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
num_experts=0, drop_connect_rate=0.):
num_experts=0, drop_path_rate=0.):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
@ -302,7 +289,7 @@ class CondConvResidual(InvertedResidual):
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
drop_connect_rate=drop_connect_rate)
drop_path_rate=drop_path_rate)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
@ -332,8 +319,8 @@ class CondConvResidual(InvertedResidual):
x = self.bn3(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
@ -344,7 +331,7 @@ class EdgeResidual(nn.Module):
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
drop_connect_rate=0.):
drop_path_rate=0.):
super(EdgeResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
if fake_in_chs > 0:
@ -353,7 +340,7 @@ class EdgeResidual(nn.Module):
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_connect_rate = drop_connect_rate
self.drop_path_rate = drop_path_rate
# Expansion convolution
self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
@ -400,8 +387,8 @@ class EdgeResidual(nn.Module):
x = self.bn2(x)
if self.has_residual:
if self.drop_connect_rate > 0.:
x = drop_connect(x, self.training, self.drop_connect_rate)
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x

@ -202,7 +202,7 @@ class EfficientNetBuilder:
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0., feature_location='',
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
verbose=False):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
@ -213,7 +213,7 @@ class EfficientNetBuilder:
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_connect_rate = drop_connect_rate
self.drop_path_rate = drop_path_rate
self.feature_location = feature_location
assert feature_location in ('pre_pwl', 'post_exp', '')
self.verbose = verbose
@ -226,7 +226,7 @@ class EfficientNetBuilder:
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba, block_idx, block_count):
drop_connect_rate = self.drop_connect_rate * block_idx / block_count
drop_path_rate = self.drop_path_rate * block_idx / block_count
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
@ -240,7 +240,7 @@ class EfficientNetBuilder:
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['drop_connect_rate'] = drop_connect_rate
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
@ -249,13 +249,13 @@ class EfficientNetBuilder:
else:
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_connect_rate'] = drop_connect_rate
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_connect_rate'] = drop_connect_rate
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))

@ -31,7 +31,21 @@ def create_model(
kwargs.pop('bn_tf', None)
kwargs.pop('bn_momentum', None)
kwargs.pop('bn_eps', None)
kwargs.pop('drop_connect_rate', None)
# Parameters that aren't supported by all models should default to None in command line args,
# remove them if they are present and not set so that non-supporting models don't break.
if kwargs.get('drop_block_rate', None) is None:
kwargs.pop('drop_block_rate', None)
# handle backwards compat with drop_connect -> drop_path change
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
" Setting drop_path to %f." % drop_connect_rate)
kwargs['drop_path_rate'] = drop_connect_rate
if kwargs.get('drop_path_rate', None) is None:
kwargs.pop('drop_path_rate', None)
if is_model(model_name):
create_fn = model_entrypoint(model_name)

@ -71,7 +71,7 @@ class MobileNetV3(nn.Module):
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(MobileNetV3, self).__init__()
@ -90,7 +90,7 @@ class MobileNetV3(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features
self._in_chs = builder.in_chs
@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None,
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {}
@ -170,7 +170,7 @@ class MobileNetV3Features(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG)
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features # builder provides info about feature channels for each block
self._in_chs = builder.in_chs

@ -81,10 +81,14 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
help='ratio of validation batch size to training batch size (default: 1)')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
help='Drop connect rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
parser.add_argument('--jsd', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
# Optimizer parameters
@ -242,7 +246,9 @@ def main():
pretrained=args.pretrained,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum,

Loading…
Cancel
Save