Fix ResNet based models to work w/ norm layers w/o affine params. Reformat long arg lists into vertical form.

pull/1612/head
Ross Wightman 2 years ago
parent d5aa17e415
commit 6902c48a5f

@ -962,9 +962,21 @@ class BasicBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
group_size=None,
bottle_ratio=1.0,
downsample='avg',
attn_last=True,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(BasicBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -983,7 +995,7 @@ class BasicBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1005,9 +1017,23 @@ class BottleneckBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.,
group_size=None,
downsample='avg',
attn_last=False,
linear_out=False,
extra_conv=False,
bottle_in=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@ -1031,7 +1057,7 @@ class BottleneckBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1063,9 +1089,21 @@ class DarkBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='avg',
attn_last=True,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -1085,7 +1123,7 @@ class DarkBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1114,9 +1152,21 @@ class EdgeBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='avg',
attn_last=False,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -1135,7 +1185,7 @@ class EdgeBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1162,8 +1212,19 @@ class RepVggBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='',
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@ -1204,9 +1265,24 @@ class SelfAttnBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.,
group_size=None,
downsample='avg',
extra_conv=False,
linear_out=False,
bottle_in=False,
post_attn_na=True,
feat_size=None,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@ -1233,7 +1309,7 @@ class SelfAttnBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
if hasattr(self.self_attn, 'reset_parameters'):
self.self_attn.reset_parameters()
@ -1274,8 +1350,17 @@ def create_block(block: Union[str, nn.Module], **kwargs):
class Stem(nn.Sequential):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
self,
in_chs,
out_chs,
kernel_size=3,
stride=4,
pool='maxpool',
num_rep=3,
num_act=None,
chs_decay=0.5,
layers: LayerFn = None,
):
super().__init__()
assert stride in (2, 4)
layers = layers or LayerFn()
@ -1319,7 +1404,14 @@ class Stem(nn.Sequential):
assert curr_stride == stride
def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
def create_byob_stem(
in_chs,
out_chs,
stem_type='',
pool_type='',
feat_prefix='stem',
layers: LayerFn = None,
):
layers = layers or LayerFn()
assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
if 'quad' in stem_type:
@ -1407,10 +1499,14 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
def create_byob_stages(
cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
cfg: ByoModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any],
feat_size: Optional[int] = None,
layers: Optional[LayerFn] = None,
block_kwargs_fn: Optional[Callable] = update_block_kwargs):
block_kwargs_fn: Optional[Callable] = update_block_kwargs,
):
layers = layers or LayerFn()
feature_info = []
@ -1485,8 +1581,17 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(
self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
self,
cfg: ByoModelCfg,
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
zero_init_last=True,
img_size=None,
drop_rate=0.,
drop_path_rate=0.,
):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate

@ -51,9 +51,21 @@ class Bottle2neck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=26,
scale=4,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=None,
attn_layer=None,
**_,
):
super(Bottle2neck, self).__init__()
self.scale = scale
self.is_first = stride > 1 or downsample is not None
@ -89,7 +101,8 @@ class Bottle2neck(nn.Module):
self.downsample = downsample
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
shortcut = x

@ -57,10 +57,27 @@ class ResNestBottleneck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None,
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
radix=1,
cardinality=1,
base_width=64,
avd=False,
avd_first=False,
is_first=False,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(ResNestBottleneck, self).__init__()
assert reduce_first == 1 # not supported
assert attn_layer is None # not supported
@ -103,7 +120,8 @@ class ResNestBottleneck(nn.Module):
self.downsample = downsample
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
shortcut = x

@ -337,9 +337,23 @@ class BasicBlock(nn.Module):
expansion = 1
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -370,7 +384,8 @@ class BasicBlock(nn.Module):
self.drop_path = drop_path
def zero_init_last(self):
nn.init.zeros_(self.bn2.weight)
if getattr(self.bn2, 'weight', None) is not None:
nn.init.zeros_(self.bn2.weight)
def forward(self, x):
shortcut = x
@ -402,9 +417,23 @@ class Bottleneck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
@ -437,7 +466,8 @@ class Bottleneck(nn.Module):
self.drop_path = drop_path
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
shortcut = x
@ -508,8 +538,18 @@ def drop_blocks(drop_prob=0.):
def make_blocks(
block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
block_fn,
channels,
block_repeats,
inplanes,
reduce_first=1,
output_stride=32,
down_kernel_size=1,
avg_down=False,
drop_block_rate=0.,
drop_path_rate=0.,
**kwargs,
):
stages = []
feature_info = []
net_num_blocks = sum(block_repeats)
@ -528,8 +568,14 @@ def make_blocks(
downsample = None
if stride != 1 or inplanes != planes * block_fn.expansion:
down_kwargs = dict(
in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
in_channels=inplanes,
out_channels=planes * block_fn.expansion,
kernel_size=down_kernel_size,
stride=stride,
dilation=dilation,
first_dilation=prev_dilation,
norm_layer=kwargs.get('norm_layer'),
)
downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
@ -609,10 +655,30 @@ class ResNet(nn.Module):
"""
def __init__(
self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
self,
block,
layers,
num_classes=1000,
in_chans=3,
output_stride=32,
global_pool='avg',
cardinality=1,
base_width=64,
stem_width=64,
stem_type='',
replace_stem_pool=False,
block_reduce_first=1,
down_kernel_size=1,
avg_down=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None,
drop_rate=0.0,
drop_path_rate=0.,
drop_block_rate=0.,
zero_init_last=True,
block_args=None,
):
super(ResNet, self).__init__()
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
@ -663,10 +729,23 @@ class ResNet(nn.Module):
# Feature Blocks
channels = [64, 128, 256, 512]
stage_modules, stage_feature_info = make_blocks(
block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
block,
channels,
layers,
inplanes,
cardinality=cardinality,
base_width=base_width,
output_stride=output_stride,
reduce_first=block_reduce_first,
avg_down=avg_down,
down_kernel_size=down_kernel_size,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
drop_block_rate=drop_block_rate,
drop_path_rate=drop_path_rate,
**block_args,
)
for stage in stage_modules:
self.add_module(*stage) # layer1, layer2, etc
self.feature_info.extend(stage_feature_info)
@ -687,9 +766,6 @@ class ResNet(nn.Module):
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
if zero_init_last:
for m in self.modules():
if hasattr(m, 'zero_init_last'):

@ -155,8 +155,20 @@ class PreActBottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
self,
in_chs,
out_chs=None,
bottle_ratio=0.25,
stride=1,
dilation=1,
first_dilation=None,
groups=1,
act_layer=None,
conv_layer=None,
norm_layer=None,
proj_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
conv_layer = conv_layer or StdConv2d
@ -202,8 +214,20 @@ class Bottleneck(nn.Module):
"""Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
"""
def __init__(
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
self,
in_chs,
out_chs=None,
bottle_ratio=0.25,
stride=1,
dilation=1,
first_dilation=None,
groups=1,
act_layer=None,
conv_layer=None,
norm_layer=None,
proj_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
act_layer = act_layer or nn.ReLU
@ -229,7 +253,8 @@ class Bottleneck(nn.Module):
self.act3 = act_layer(inplace=True)
def zero_init_last(self):
nn.init.zeros_(self.norm3.weight)
if getattr(self.norm3, 'weight', None) is not None:
nn.init.zeros_(self.norm3.weight)
def forward(self, x):
# shortcut branch
@ -283,9 +308,22 @@ class DownsampleAvg(nn.Module):
class ResNetStage(nn.Module):
"""ResNet Stage."""
def __init__(
self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
self,
in_chs,
out_chs,
stride,
dilation,
depth,
bottle_ratio=0.25,
groups=1,
avg_down=False,
block_dpr=None,
block_fn=PreActBottleneck,
act_layer=None,
conv_layer=None,
norm_layer=None,
**block_kwargs,
):
super(ResNetStage, self).__init__()
first_dilation = 1 if dilation in (1, 2) else 2
layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
@ -313,8 +351,13 @@ def is_stem_deep(stem_type):
def create_resnetv2_stem(
in_chs, out_chs=64, stem_type='', preact=True,
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
in_chs,
out_chs=64,
stem_type='',
preact=True,
conv_layer=StdConv2d,
norm_layer=partial(GroupNormAct, num_groups=32),
):
stem = OrderedDict()
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
@ -357,11 +400,25 @@ class ResNetV2(nn.Module):
"""
def __init__(
self, layers, channels=(256, 512, 1024, 2048),
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0., zero_init_last=False):
self,
layers,
channels=(256, 512, 1024, 2048),
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
width_factor=1,
stem_chs=64,
stem_type='',
avg_down=False,
preact=True,
act_layer=nn.ReLU,
conv_layer=StdConv2d,
norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=False,
):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -387,8 +444,18 @@ class ResNetV2(nn.Module):
dilation *= stride
stride = 1
stage = ResNetStage(
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
prev_chs,
out_chs,
stride=stride,
dilation=dilation,
depth=d,
avg_down=avg_down,
act_layer=act_layer,
conv_layer=conv_layer,
norm_layer=norm_layer,
block_dpr=bdpr,
block_fn=block_fn,
)
prev_chs = out_chs
curr_stride *= stride
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]

@ -47,9 +47,24 @@ class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
sk_kwargs=None,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
@ -71,7 +86,8 @@ class SelectiveKernelBasic(nn.Module):
self.drop_path = drop_path
def zero_init_last(self):
nn.init.zeros_(self.conv2.bn.weight)
if getattr(self.conv2.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2.bn.weight)
def forward(self, x):
shortcut = x
@ -92,9 +108,24 @@ class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
sk_kwargs=None,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
@ -115,7 +146,8 @@ class SelectiveKernelBottleneck(nn.Module):
self.drop_path = drop_path
def zero_init_last(self):
nn.init.zeros_(self.conv3.bn.weight)
if getattr(self.conv3.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
shortcut = x

Loading…
Cancel
Save