diff --git a/timm/models/res2net.py b/timm/models/res2net.py
index 3b503e52..da20e7a0 100644
--- a/timm/models/res2net.py
+++ b/timm/models/res2net.py
@@ -54,9 +54,8 @@ class Bottle2neck(nn.Module):
 
     def __init__(self, inplanes, planes, stride=1, downsample=None,
                  cardinality=1, base_width=26, scale=4, use_se=False,
-                 norm_layer=None, dilation=1, previous_dilation=1, **_):
+                 act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_):
         super(Bottle2neck, self).__init__()
-        assert dilation == 1 and previous_dilation == 1  # FIXME support dilation
         self.scale = scale
         self.is_first = stride > 1 or downsample is not None
         self.num_scales = max(1, scale - 1)
@@ -71,18 +70,20 @@ class Bottle2neck(nn.Module):
         bns = []
         for i in range(self.num_scales):
             convs.append(nn.Conv2d(
-                width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False))
+                width, width, kernel_size=3, stride=stride, padding=dilation,
+                dilation=dilation, groups=cardinality, bias=False))
             bns.append(norm_layer(width))
         self.convs = nn.ModuleList(convs)
         self.bns = nn.ModuleList(bns)
 
         if self.is_first:
+            # FIXME this should probably have count_include_pad=False, but hurts original weights
             self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
 
         self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
         self.se = SEModule(outplanes, planes // 4) if use_se else None
-        self.relu = nn.ReLU(inplace=True)
+        self.relu = act_layer(inplace=True)
         self.downsample = downsample
 
     def forward(self, x):
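Note on the padding change above: pairing padding=dilation with dilation=dilation keeps each 3x3 per-scale conv size-preserving at stride 1 for any dilation rate, which is what makes it safe to drop the old dilation assert. A minimal standalone sketch (shapes only, not part of the patch):

import torch
import torch.nn as nn

# For kernel_size=3, stride=1: out = in + 2*padding - dilation*(3 - 1), so
# padding == dilation leaves H x W unchanged for any dilation rate.
x = torch.randn(1, 26, 56, 56)
for d in (1, 2, 4):
    conv = nn.Conv2d(26, 26, kernel_size=3, stride=1, padding=d, dilation=d, bias=False)
    assert conv(x).shape == x.shape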
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 67e1920b..abde4063 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -125,11 +125,12 @@ class SEModule(nn.Module):
 
 
 class BasicBlock(nn.Module):
+    __constants__ = ['se', 'downsample']  # for pre 1.4 torchscript compat
     expansion = 1
 
     def __init__(self, inplanes, planes, stride=1, downsample=None,
                  cardinality=1, base_width=64, use_se=False,
-                 reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
+                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
         super(BasicBlock, self).__init__()
 
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -141,12 +142,13 @@ class BasicBlock(nn.Module):
             inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
             dilation=dilation, bias=False)
         self.bn1 = norm_layer(first_planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.act1 = act_layer(inplace=True)
         self.conv2 = nn.Conv2d(
             first_planes, outplanes, kernel_size=3,
             padding=previous_dilation, dilation=previous_dilation, bias=False)
         self.bn2 = norm_layer(outplanes)
         self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.act2 = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
@@ -156,7 +158,7 @@ class BasicBlock(nn.Module):
 
         out = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.act1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
@@ -167,17 +169,18 @@ class BasicBlock(nn.Module):
             residual = self.downsample(x)
 
         out += residual
-        out = self.relu(out)
+        out = self.act2(out)
 
         return out
 
 
 class Bottleneck(nn.Module):
+    __constants__ = ['se', 'downsample']  # for pre 1.4 torchscript compat
     expansion = 4
 
     def __init__(self, inplanes, planes, stride=1, downsample=None,
                  cardinality=1, base_width=64, use_se=False,
-                 reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
+                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
         super(Bottleneck, self).__init__()
 
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -186,14 +189,16 @@ class Bottleneck(nn.Module):
 
         self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
         self.bn1 = norm_layer(first_planes)
+        self.act1 = act_layer(inplace=True)
         self.conv2 = nn.Conv2d(
             first_planes, width, kernel_size=3, stride=stride,
             padding=dilation, dilation=dilation, groups=cardinality, bias=False)
         self.bn2 = norm_layer(width)
+        self.act2 = act_layer(inplace=True)
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
         self.se = SEModule(outplanes, planes // 4) if use_se else None
-        self.relu = nn.ReLU(inplace=True)
+        self.act3 = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
@@ -203,11 +208,11 @@ class Bottleneck(nn.Module):
 
         out = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.act1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
-        out = self.relu(out)
+        out = self.act2(out)
 
         out = self.conv3(out)
         out = self.bn3(out)
@@ -219,7 +224,7 @@ class Bottleneck(nn.Module):
             residual = self.downsample(x)
 
         out += residual
-        out = self.relu(out)
+        out = self.act3(out)
 
         return out
 
@@ -284,9 +289,10 @@ class ResNet(nn.Module):
         Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
     avg_down : bool, default False
         Whether to use average pooling for projection skip connection between stages/downsample.
-    dilated : bool, default False
-        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
-        typically used in Semantic Segmentation.
+    output_stride : int, default 32
+        Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
+    act_layer : class, activation layer
+    norm_layer : class, normalization layer
     drop_rate : float, default 0.
         Dropout probability before classifier, for training
     global_pool : str, default 'avg'
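Usage-wise, the new output_stride knob replaces the boolean dilated flag with three supported reductions. A hedged sketch of what it buys, assuming the patched module is importable as timm.models.resnet:

import torch
from timm.models.resnet import ResNet, Bottleneck

# output_stride is the total input-to-feature downsampling ratio; 8 converts
# layer3/layer4 to stride 1 with dilations 2/4, so 224 -> 224/8 = 28.
model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride=8).eval()
with torch.no_grad():
    feats = model.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 2048, 28, 28])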
@@ -294,8 +300,8 @@ class ResNet(nn.Module):
     """
     def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
-                 block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
-                 norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
+                 block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
                  zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         self.num_classes = num_classes
@@ -305,9 +311,9 @@ class ResNet(nn.Module):
         self.base_width = base_width
         self.drop_rate = drop_rate
         self.expansion = block.expansion
-        self.dilated = dilated
         super(ResNet, self).__init__()
 
+        # Stem
         if deep_stem:
             stem_chs_1 = stem_chs_2 = stem_width
             if 'tiered' in stem_type:
@@ -316,25 +322,37 @@ class ResNet(nn.Module):
             self.conv1 = nn.Sequential(*[
                 nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False),
                 norm_layer(stem_chs_1),
-                nn.ReLU(inplace=True),
+                act_layer(inplace=True),
                 nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
                 norm_layer(stem_chs_2),
-                nn.ReLU(inplace=True),
+                act_layer(inplace=True),
                 nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
         else:
             self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
-        self.relu = nn.ReLU(inplace=True)
+        self.act1 = act_layer(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        stride_3_4 = 1 if self.dilated else 2
-        dilation_3 = 2 if self.dilated else 1
-        dilation_4 = 4 if self.dilated else 1
-        largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer,
-                     avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
-        self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs)
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, **largs)
+
+        # Feature Blocks
+        channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
+        if output_stride == 16:
+            strides[3] = 1
+            dilations[3] = 2
+        elif output_stride == 8:
+            strides[2:4] = [1, 1]
+            dilations[2:4] = [2, 4]
+        else:
+            assert output_stride == 32
+        llargs = list(zip(channels, layers, strides, dilations))
+        lkwargs = dict(
+            use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
+            avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
+        self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
+        self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
+        self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
+        self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
+
+        # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_features = 512 * block.expansion
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
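The stride/dilation schedule above is easy to sanity check in isolation. This standalone sketch mirrors the hunk's logic (layer_config is an illustrative name, not part of the patch) and multiplies out the net stride, including the stem's fixed 4x reduction:

def layer_config(output_stride=32):
    # Mirrors the constructor: per-layer strides for layer1..layer4, with
    # stride converted to dilation where the output stride is capped.
    strides, dilations = [1, 2, 2, 2], [1, 1, 1, 1]
    if output_stride == 16:
        strides[3], dilations[3] = 1, 2
    elif output_stride == 8:
        strides[2:4], dilations[2:4] = [1, 1], [2, 4]
    else:
        assert output_stride == 32
    return strides, dilations

for os_ in (32, 16, 8):
    strides, dilations = layer_config(os_)
    net = 4  # stem: stride-2 conv followed by stride-2 max pool
    for s in strides:
        net *= s
    assert net == os_, (os_, strides, dilations)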
@@ -352,7 +370,8 @@ class ResNet(nn.Module):
                     nn.init.constant_(m.bias, 0.)
 
     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d, **kwargs):
+                    use_se=False, avg_down=False, down_kernel_size=1, **kwargs):
+        norm_layer = kwargs.get('norm_layer')
         downsample = None
         down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
         if stride != 1 or self.inplanes != planes * block.expansion:
@@ -370,15 +389,15 @@ class ResNet(nn.Module):
             downsample = nn.Sequential(*downsample_layers)
 
         first_dilation = 1 if dilation in (1, 2) else 2
-        bargs = dict(
+        bkwargs = dict(
             cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            use_se=use_se, norm_layer=norm_layer, **kwargs)
+            use_se=use_se, **kwargs)
         layers = [block(
-            self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bargs)]
+            self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
             layers.append(block(
-                self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bargs))
+                self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs))
 
         return nn.Sequential(*layers)
 
@@ -394,7 +413,7 @@ class ResNet(nn.Module):
     def forward_features(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
-        x = self.relu(x)
+        x = self.act1(x)
         x = self.maxpool(x)
 
         x = self.layer1(x)
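Finally, the act_layer plumbing end to end: any module constructor accepting inplace= can stand in for the previously hard-coded nn.ReLU, in both the stem and every block. A hedged sketch, again assuming the patched module path:

import torch.nn as nn
from timm.models.resnet import ResNet, BasicBlock

# ResNet-18-shaped model with LeakyReLU throughout; the act_layer class must
# accept inplace=True, as nn.LeakyReLU does.
model = ResNet(BasicBlock, [2, 2, 2, 2], act_layer=nn.LeakyReLU)
assert isinstance(model.act1, nn.LeakyReLU)
assert isinstance(model.layer1[0].act1, nn.LeakyReLU)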