diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index f012c3cf..33450483 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -15,3 +15,4 @@ from .adaptive_avgmax_pool import \ from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .blurpool import BlurPool2d diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py index 96937114..0b37a90c 100644 --- a/timm/models/layers/blurpool.py +++ b/timm/models/layers/blurpool.py @@ -17,7 +17,7 @@ class BlurPool2d(nn.Module): Corresponds to the Downsample class, which does blurring and subsampling Args: channels = Number of input channels - blur_filter_size (int): binomial filter size for blurring. currently supports 3(default) and 5. + blur_filter_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. stride (int): downsampling filter stride Shape: Returns: diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 0013cbe0..057eca6c 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn +from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -104,6 +104,8 @@ default_cfgs = { interpolation='bicubic'), 'ecaresnet18': _cfg(), 'ecaresnet50': _cfg(), + 'resnetblur18': _cfg(), + 'resnetblur50': _cfg() } @@ -117,7 +119,7 @@ class BasicBlock(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, drop_block=None, drop_path=None): + attn_layer=None, drop_block=None, drop_path=None, blur=False): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' @@ -125,10 +127,19 @@ class BasicBlock(nn.Module): first_planes = planes // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation + self.blur = blur - self.conv1 = nn.Conv2d( + if blur and stride==2: + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1, padding=first_dilation, + dilation=first_dilation, bias=False) + self.blurpool=BlurPool2d(channels=first_planes) + else: + self.conv1 = nn.Conv2d( inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, dilation=first_dilation, bias=False) + self.blurpool = None + self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( @@ -154,7 +165,11 @@ class BasicBlock(nn.Module): x = self.bn1(x) if self.drop_block is not None: x = self.drop_block(x) - x = self.act1(x) + if self.blurpool is not None: + x = self.act1(x) + x = self.blurpool(x) + else: + x = self.act1(x) x = self.conv2(x) x = self.bn2(x) @@ -181,20 +196,30 @@ class Bottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, drop_block=None, drop_path=None): + attn_layer=None, drop_block=None, drop_path=None, blur=False): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation + self.blur = blur self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) - self.conv2 = nn.Conv2d( - first_planes, width, kernel_size=3, stride=stride, - padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + + if blur and stride==2: + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.blurpool = BlurPool2d(channels=width) + else: + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.blurpool = None + self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) @@ -345,12 +370,19 @@ class ResNet(nn.Module): Dropout probability before classifier, for training global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' + blur : str, default '' + Location of Blurring: + * '', default - Not applied + * 'max' - only stem layer MaxPool will be blurred + * 'strided' - only strided convolutions in the downsampling blocks (assembled-cnn style) + * 'max_strided' - on both stem MaxPool and strided convolutions (zhang2019shiftinvar style for ResNets) + """ def __init__(self, block, layers, num_classes=1000, in_chans=3, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., - drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + drop_block_rate=0., global_pool='avg', blur='', zero_init_last_bn=True, block_args=None): block_args = block_args or dict() self.num_classes = num_classes deep_stem = 'deep' in stem_type @@ -359,6 +391,7 @@ class ResNet(nn.Module): self.base_width = base_width self.drop_rate = drop_rate self.expansion = block.expansion + self.blur = 'strided' in blur super(ResNet, self).__init__() # Stem @@ -379,7 +412,13 @@ class ResNet(nn.Module): self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.act1 = act_layer(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + # Stem Blur + if 'max' in blur : + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + BlurPool2d(channels=self.inplanes)]) + else : + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks dp = DropPath(drop_path_rate) if drop_path_rate else None @@ -432,7 +471,7 @@ class ResNet(nn.Module): block_kwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, dilation=dilation, **kwargs) - layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)] + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, blur=self.blur, **block_kwargs)] self.inplanes = planes * block.expansion layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)] @@ -1022,3 +1061,21 @@ def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + +@register_model +def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. With original style blur + """ + default_cfg = default_cfgs['resnetblur18'] + model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, blur='max_strided',**kwargs) + model.default_cfg = default_cfg + return model + +@register_model +def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-50 model. With assembled-cnn style blur + """ + default_cfg = default_cfgs['resnetblur18'] + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, blur='strided', **kwargs) + model.default_cfg = default_cfg + return model \ No newline at end of file