Make drop_connect rate scaling match official impl. Fixes #14

pull/16/head
Ross Wightman 5 years ago
parent 13c19e213d
commit 9d653b68a2

@ -318,7 +318,11 @@ class _BlockBuilder:
self.folded_bn = folded_bn self.folded_bn = folded_bn
self.padding_same = padding_same self.padding_same = padding_same
self.verbose = verbose self.verbose = verbose
# updated during build
self.in_chs = None self.in_chs = None
self.block_idx = 0
self.block_count = 0
def _round_channels(self, chs): def _round_channels(self, chs):
return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
@ -334,35 +338,40 @@ class _BlockBuilder:
# block act fn overrides the model default # block act fn overrides the model default
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None assert ba['act_fn'] is not None
if self.verbose:
logging.info(' Args: {}'.format(str(ba)))
# could replace this if with lambdas or functools binding if variety increases
if bt == 'ir': if bt == 'ir':
ba['drop_connect_rate'] = self.drop_connect_rate ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
ba['se_gate_fn'] = self.se_gate_fn ba['se_gate_fn'] = self.se_gate_fn
ba['se_reduce_mid'] = self.se_reduce_mid ba['se_reduce_mid'] = self.se_reduce_mid
if self.verbose:
logging.info(' InvertedResidual {}, Args: {}'.format(self.block_idx, str(ba)))
block = InvertedResidual(**ba) block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa': elif bt == 'ds' or bt == 'dsa':
ba['drop_connect_rate'] = self.drop_connect_rate ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
if self.verbose:
logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba) block = DepthwiseSeparableConv(**ba)
elif bt == 'cn': elif bt == 'cn':
if self.verbose:
logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba)))
block = ConvBnAct(**ba) block = ConvBnAct(**ba)
else: else:
assert False, 'Uknkown block type (%s) while building model.' % bt assert False, 'Uknkown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block return block
def _make_stack(self, stack_args): def _make_stack(self, stack_args):
blocks = [] blocks = []
# each stack (stage) contains a list of block arguments # each stack (stage) contains a list of block arguments
for block_idx, ba in enumerate(stack_args): for i, ba in enumerate(stack_args):
if self.verbose: if self.verbose:
logging.info(' Block: {}'.format(block_idx)) logging.info(' Block: {}'.format(i))
if block_idx >= 1: if i >= 1:
# only the first block in any stack/stage can have a stride > 1 # only the first block in any stack can have a stride > 1
ba['stride'] = 1 ba['stride'] = 1
block = self._make_block(ba) block = self._make_block(ba)
blocks.append(block) blocks.append(block)
self.block_idx += 1 # incr global idx (across all stacks)
return nn.Sequential(*blocks) return nn.Sequential(*blocks)
def __call__(self, in_chs, block_args): def __call__(self, in_chs, block_args):
@ -377,6 +386,8 @@ class _BlockBuilder:
if self.verbose: if self.verbose:
logging.info('Building model trunk with %d stages...' % len(block_args)) logging.info('Building model trunk with %d stages...' % len(block_args))
self.in_chs = in_chs self.in_chs = in_chs
self.block_count = sum([len(x) for x in block_args])
self.block_idx = 0
blocks = [] blocks = []
# outer list of block_args defines the stacks ('stages' by some conventions) # outer list of block_args defines the stacks ('stages' by some conventions)
for stack_idx, stack in enumerate(block_args): for stack_idx, stack in enumerate(block_args):
@ -1404,6 +1415,7 @@ def efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B0 """ """ EfficientNet-B0 """
default_cfg = default_cfgs['efficientnet_b0'] default_cfg = default_cfgs['efficientnet_b0']
# NOTE for train, drop_rate should be 0.2 # NOTE for train, drop_rate should be 0.2
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
model = _gen_efficientnet( model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.0, channel_multiplier=1.0, depth_multiplier=1.0,
num_classes=num_classes, in_chans=in_chans, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1418,6 +1430,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B1 """ """ EfficientNet-B1 """
default_cfg = default_cfgs['efficientnet_b1'] default_cfg = default_cfgs['efficientnet_b1']
# NOTE for train, drop_rate should be 0.2 # NOTE for train, drop_rate should be 0.2
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
model = _gen_efficientnet( model = _gen_efficientnet(
channel_multiplier=1.0, depth_multiplier=1.1, channel_multiplier=1.0, depth_multiplier=1.1,
num_classes=num_classes, in_chans=in_chans, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1432,6 +1445,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B2 """ """ EfficientNet-B2 """
default_cfg = default_cfgs['efficientnet_b2'] default_cfg = default_cfgs['efficientnet_b2']
# NOTE for train, drop_rate should be 0.3 # NOTE for train, drop_rate should be 0.3
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
model = _gen_efficientnet( model = _gen_efficientnet(
channel_multiplier=1.1, depth_multiplier=1.2, channel_multiplier=1.1, depth_multiplier=1.2,
num_classes=num_classes, in_chans=in_chans, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1446,6 +1460,7 @@ def efficientnet_b3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B3 """ """ EfficientNet-B3 """
default_cfg = default_cfgs['efficientnet_b3'] default_cfg = default_cfgs['efficientnet_b3']
# NOTE for train, drop_rate should be 0.3 # NOTE for train, drop_rate should be 0.3
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
model = _gen_efficientnet( model = _gen_efficientnet(
channel_multiplier=1.2, depth_multiplier=1.4, channel_multiplier=1.2, depth_multiplier=1.4,
num_classes=num_classes, in_chans=in_chans, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1460,6 +1475,7 @@ def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B4 """ """ EfficientNet-B4 """
default_cfg = default_cfgs['efficientnet_b4'] default_cfg = default_cfgs['efficientnet_b4']
# NOTE for train, drop_rate should be 0.4 # NOTE for train, drop_rate should be 0.4
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
model = _gen_efficientnet( model = _gen_efficientnet(
channel_multiplier=1.4, depth_multiplier=1.8, channel_multiplier=1.4, depth_multiplier=1.8,
num_classes=num_classes, in_chans=in_chans, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -1473,6 +1489,7 @@ def efficientnet_b4(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
""" EfficientNet-B5 """ """ EfficientNet-B5 """
# NOTE for train, drop_rate should be 0.4 # NOTE for train, drop_rate should be 0.4
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
default_cfg = default_cfgs['efficientnet_b5'] default_cfg = default_cfgs['efficientnet_b5']
model = _gen_efficientnet( model = _gen_efficientnet(
channel_multiplier=1.6, depth_multiplier=2.2, channel_multiplier=1.6, depth_multiplier=2.2,

Loading…
Cancel
Save