From a27f4aec4aaa22c6a6e82c7d8a9a69d73176525e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 May 2021 14:06:34 -0700 Subject: [PATCH] Missed args for skresnext w/ refactoring. --- timm/models/layers/selective_kernel.py | 2 +- timm/models/sknet.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 246f72a6..bf7df4d2 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -49,7 +49,7 @@ class SelectiveKernelAttn(nn.Module): class SelectiveKernel(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, - rd_ratio=1./16, rd_channels=None, min_rd_channels=16, rd_divisor=8, keep_3x3=True, split_input=True, + rd_ratio=1./16, rd_channels=None, min_rd_channels=32, rd_divisor=8, keep_3x3=True, split_input=True, drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): """ Selective Kernel Convolution Module diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 82ca5bfe..bba8bcf9 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -207,8 +207,9 @@ def skresnext50_32x4d(pretrained=False, **kwargs): """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to the SKNet-50 model in the Select Kernel Paper """ + sk_kwargs = dict(min_rd_channels=32, rd_ratio=1/16, split_input=False) model_args = dict( block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, - zero_init_last_bn=False, **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)