diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py
index c91b969e..06f89838 100644
--- a/timm/models/layers/drop.py
+++ b/timm/models/layers/drop.py
@@ -156,7 +156,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
     return output
 
 
-class DropPath(nn.ModuleDict):
+class DropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
     def __init__(self, drop_prob=None):
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 1b87ed08..7c243297 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -205,14 +205,14 @@ class BasicBlock(nn.Module):
         first_planes = planes // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
 
         self.conv1 = nn.Conv2d(
             inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
             dilation=first_dilation, bias=False)
         self.bn1 = norm_layer(first_planes)
         self.act1 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
 
         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@@ -272,7 +272,7 @@ class Bottleneck(nn.Module):
         first_planes = width // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
 
         self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
         self.bn1 = norm_layer(first_planes)
@@ -283,7 +283,7 @@
             padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
         self.bn2 = norm_layer(width)
         self.act2 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
 
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
@@ -336,14 +336,6 @@ class Bottleneck(nn.Module):
         return x
 
 
-def setup_drop_block(drop_block_rate=0.):
-    return [
-        None,
-        None,
-        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
-        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
-
-
 def downsample_conv(
         in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
     norm_layer = norm_layer or nn.BatchNorm2d
@@ -375,6 +367,57 @@ def downsample_avg(
     ])
 
 
+def drop_blocks(drop_block_rate=0.):
+    return [
+        None, None,
+        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
+        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
+
+
+def make_blocks(
+        block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
+        down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
+    stages = []
+    feature_info = []
+    net_num_blocks = sum(block_repeats)
+    net_block_idx = 0
+    net_stride = 4
+    dilation = prev_dilation = 1
+    for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+        stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
+        stride = 1 if stage_idx == 0 else 2
+        if net_stride >= output_stride:
+            dilation *= stride
+            stride = 1
+        else:
+            net_stride *= stride
+
+        downsample = None
+        if stride != 1 or inplanes != planes * block_fn.expansion:
+            down_kwargs = dict(
+                in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
+                stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
+            downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
+
+        block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
+        blocks = []
+        for block_idx in range(num_blocks):
+            downsample = downsample if block_idx == 0 else None
+            stride = stride if block_idx == 0 else 1
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            blocks.append(block_fn(
+                inplanes, planes, stride, downsample, first_dilation=prev_dilation,
+                drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs))
+            prev_dilation = dilation
+            inplanes = planes * block_fn.expansion
+            net_block_idx += 1
+
+        stages.append((stage_name, nn.Sequential(*blocks)))
+        feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
+
+    return stages, feature_info
+
+
 class ResNet(nn.Module):
     """ResNet / ResNeXt / SE-ResNeXt / SE-Net
 
@@ -448,21 +491,18 @@ class ResNet(nn.Module):
 
     def __init__(self, block, layers, num_classes=1000, in_chans=3,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
-                 block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
+                 output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
                  drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
         self.num_classes = num_classes
-        deep_stem = 'deep' in stem_type
-        self.inplanes = stem_width * 2 if deep_stem else 64
-        self.cardinality = cardinality
-        self.base_width = base_width
         self.drop_rate = drop_rate
-        self.expansion = block.expansion
         super(ResNet, self).__init__()
 
         # Stem
+        deep_stem = 'deep' in stem_type
+        inplanes = stem_width * 2 if deep_stem else 64
         if deep_stem:
             stem_chs_1 = stem_chs_2 = stem_width
             if 'tiered' in stem_type:
@@ -475,43 +515,31 @@ class ResNet(nn.Module):
                 nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
                 norm_layer(stem_chs_2),
                 act_layer(inplace=True),
-                nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
+                nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)])
         else:
-            self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
-        self.bn1 = norm_layer(self.inplanes)
+            self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = norm_layer(inplanes)
         self.act1 = act_layer(inplace=True)
-        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
+        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
 
         # Stem Pooling
         if aa_layer is not None:
             self.maxpool = nn.Sequential(*[
                 nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                aa_layer(channels=self.inplanes, stride=2)
-            ])
+                aa_layer(channels=inplanes, stride=2)])
         else:
             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
         channels = [64, 128, 256, 512]
-        dp = DropPath(drop_path_rate) if drop_path_rate else None
-        db = setup_drop_block(drop_block_rate)
-        layer_kwargs = dict(
-            reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
-            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
-        total_stride = 4
-        dilation = 1
-        for i in range(4):
-            layer_name = f'layer{i + 1}'
-            stride = 2 if i > 0 else 1
-            if total_stride >= output_stride:
-                dilation *= stride
-                stride = 1
-            else:
-                total_stride *= stride
-            self.add_module(layer_name, self._make_layer(
-                block, channels[i], layers[i], stride, dilation, drop_block=db[i], **layer_kwargs))
-            self.feature_info.append(dict(
-                num_chs=self.inplanes, reduction=total_stride, module=layer_name))
+        stage_modules, stage_feature_info = make_blocks(
+            block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
+            output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
+            down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
+            drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
+        for stage in stage_modules:
+            self.add_module(*stage)  # layer1, layer2, etc
+        self.feature_info.extend(stage_feature_info)
 
         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -529,25 +557,6 @@
             if hasattr(m, 'zero_init_last_bn'):
                 m.zero_init_last_bn()
 
-    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    avg_down=False, down_kernel_size=1, **kwargs):
-        downsample = None
-        first_dilation = 1 if dilation in (1, 2) else 2
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample_args = dict(
-                in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
-                stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
-            downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
-
-        block_kwargs = dict(
-            cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, **kwargs)
-        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
-        self.inplanes = planes * block.expansion
-        layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
-
-        return nn.Sequential(*layers)
-
     def get_classifier(self):
         return self.fc
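
Note (not part of the patch): the new make_blocks replaces the old single shared DropPath(drop_path_rate) with a per-block DropPath whose probability follows the stochastic depth linear decay rule. A minimal standalone sketch of that rule, assuming an illustrative ResNet-50 layout and a hypothetical drop_path_rate of 0.1:

# Sketch of the linear decay rule used in make_blocks above: block b of
# net_num_blocks total gets drop-path probability
#     drop_path_rate * b / (net_num_blocks - 1),
# i.e. 0 for the first block, rising linearly to drop_path_rate for the last.
block_repeats = [3, 4, 6, 3]   # assumed: ResNet-50 stage depths
drop_path_rate = 0.1           # assumed: network-level rate
net_num_blocks = sum(block_repeats)

block_dprs = [drop_path_rate * idx / (net_num_blocks - 1) for idx in range(net_num_blocks)]
print([round(dpr, 3) for dpr in block_dprs])
# [0.0, 0.007, 0.013, ..., 0.093, 0.1]

Under the removed code path every block shared one DropPath built from the network-level rate; after this change each block owns a DropPath scaled by its depth, so regularization strengthens toward the end of the network.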