@@ -49,7 +49,7 @@ class SelectiveKernelAttn(nn.Module):
 class SelectiveKernel(nn.Module):
 
     def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
-                 rd_ratio=1./16, rd_channels=None, min_rd_channels=32, rd_divisor=8, keep_3x3=True, split_input=True,
+                 rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
                  drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
         """ Selective Kernel Convolution Module
 
@@ -68,7 +68,6 @@ class SelectiveKernel(nn.Module):
             dilation (int): dilation for module as a whole, impacts dilation of each branch
             groups (int): number of groups for each branch
             rd_ratio (int, float): reduction factor for attention features
-            min_rd_channels (int): minimum attention feature channels
             keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
             split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
                 can be viewed as grouping by path, output expands to module out_channels count
@@ -103,8 +102,7 @@ class SelectiveKernel(nn.Module):
             ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
             for k, d in zip(kernel_size, dilation)])
 
-        attn_channels = rd_channels or make_divisible(
-            out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor)
+        attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
         self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
         self.drop_block = drop_block
 
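For context on the last hunk: with min_rd_channels removed, the width of the attention bottleneck is determined by make_divisible alone. The sketch below is illustrative and not part of the diff; make_divisible is reproduced locally (assumed to match timm's channel-rounding helper) so the before/after values can be checked in isolation.

# Illustrative sketch, not part of the diff: effect of dropping the
# min_rd_channels floor on attn_channels. make_divisible here is assumed
# to mirror timm's channel-rounding helper.
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < round_limit * v:  # don't let rounding drop v by more than ~10%
        new_v += divisor
    return new_v

for out_channels in (64, 128, 256, 512):
    old = make_divisible(out_channels * 1. / 16, min_value=32, divisor=8)  # before this change
    new = make_divisible(out_channels * 1. / 16, divisor=8)                # after this change
    print(f'out_channels={out_channels}: attn_channels {old} -> {new}')

Only narrow modules are affected: at out_channels=512 both settings yield 32 attention channels, while at out_channels=64 the old floor forced 32 where divisor-only rounding now gives 8.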