diff --git a/timm/models/regnet.py b/timm/models/regnet.py
index a93ab8a8..aab94cf0 100644
--- a/timm/models/regnet.py
+++ b/timm/models/regnet.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule
+from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath
 from .registry import register_model
 
 
@@ -195,7 +195,7 @@ class RegStage(nn.Module):
     """Stage (sequence of blocks w/ the same output shape)."""
 
     def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
-                 block_fn=Bottleneck, se_ratio=0.):
+                 block_fn=Bottleneck, se_ratio=0., drop_path_rate=None, drop_block=None):
         super(RegStage, self).__init__()
         block_kwargs = {}  # FIXME setup to pass various aa, norm, act layer common args
         first_dilation = 1 if dilation in (1, 2) else 2
@@ -203,6 +203,7 @@ class RegStage(nn.Module):
             block_stride = stride if i == 0 else 1
             block_in_chs = in_chs if i == 0 else out_chs
             block_dilation = first_dilation if i == 0 else dilation
+            drop_path = DropPath(drop_path_rate[i]) if drop_path_rate is not None else None
             if (block_in_chs != out_chs) or (block_stride != 1):
                 proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
             else:
@@ -212,7 +213,7 @@ class RegStage(nn.Module):
             self.add_module(
                 name, block_fn(
                     block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
-                    downsample=proj_block, **block_kwargs)
+                    downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs)
             )
 
     def forward(self, x):
@@ -229,7 +230,7 @@ class RegNet(nn.Module):
     """
 
     def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
-                 zero_init_last_bn=True):
+                 drop_path_rate=0., zero_init_last_bn=True):
         super().__init__()
         # TODO add drop block, drop path, anti-aliasing, custom bn/act args
         self.num_classes = num_classes
@@ -244,7 +245,7 @@ class RegNet(nn.Module):
         # Construct the stages
         prev_width = stem_width
         curr_stride = 2
-        stage_params = self._get_stage_params(cfg, output_stride=output_stride)
+        stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
         se_ratio = cfg['se_ratio']
         for i, stage_args in enumerate(stage_params):
             stage_name = "s{}".format(i + 1)
@@ -272,7 +273,7 @@ class RegNet(nn.Module):
                 if hasattr(m, 'zero_init_last_bn'):
                     m.zero_init_last_bn()
 
-    def _get_stage_params(self, cfg, default_stride=2, output_stride=32):
+    def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
         # Generate RegNet ws per block
         w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
         widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
@@ -285,24 +286,26 @@ class RegNet(nn.Module):
         stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
         stage_strides = []
         stage_dilations = []
-        total_stride = 2
+        net_stride = 2
         dilation = 1
         for _ in range(num_stages):
-            if total_stride >= output_stride:
+            if net_stride >= output_stride:
                 dilation *= default_stride
                 stride = 1
             else:
                 stride = default_stride
-                total_stride *= stride
+                net_stride *= stride
             stage_strides.append(stride)
             stage_dilations.append(dilation)
+        stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1]))
 
         # Adjust the compatibility of ws and gws
         stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
-        param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width']
+        param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rate']
         stage_params = [
             dict(zip(param_names, params)) for params in
-            zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups)]
+            zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
+                stage_dpr)]
         return stage_params
 
     def get_classifier(self):