|
|
|
@ -249,6 +249,7 @@ class Bottleneck(nn.Module):
|
|
|
|
|
x = self.norm2(x)
|
|
|
|
|
x = self.conv3(x)
|
|
|
|
|
x = self.norm3(x)
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
x = self.act3(x + shortcut)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
@ -366,9 +367,10 @@ class ResNetV2(nn.Module):
|
|
|
|
|
prev_chs = stem_chs
|
|
|
|
|
curr_stride = 4
|
|
|
|
|
dilation = 1
|
|
|
|
|
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
|
|
|
|
|
block_fn = PreActBottleneck if preact else Bottleneck
|
|
|
|
|
self.stages = nn.Sequential()
|
|
|
|
|
for stage_idx, (d, c) in enumerate(zip(layers, channels)):
|
|
|
|
|
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
|
|
|
|
|
out_chs = make_div(c * wf)
|
|
|
|
|
stride = 1 if stage_idx == 0 else 2
|
|
|
|
|
if curr_stride >= output_stride:
|
|
|
|
@ -376,7 +378,7 @@ class ResNetV2(nn.Module):
|
|
|
|
|
stride = 1
|
|
|
|
|
stage = ResNetStage(
|
|
|
|
|
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
|
|
|
|
|
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn)
|
|
|
|
|
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
|
|
|
|
|
prev_chs = out_chs
|
|
|
|
|
curr_stride *= stride
|
|
|
|
|
feat_name = f'stages.{stage_idx}'
|
|
|
|
|