Finishing adding stochastic depth support to BiT ResNetV2 models

pull/323/head
Ross Wightman 4 years ago
parent 0a1668f63e
commit d55bcc0fee

@ -249,6 +249,7 @@ class Bottleneck(nn.Module):
x = self.norm2(x)
x = self.conv3(x)
x = self.norm3(x)
x = self.drop_path(x)
x = self.act3(x + shortcut)
return x
@ -366,9 +367,10 @@ class ResNetV2(nn.Module):
prev_chs = stem_chs
curr_stride = 4
dilation = 1
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
block_fn = PreActBottleneck if preact else Bottleneck
self.stages = nn.Sequential()
for stage_idx, (d, c) in enumerate(zip(layers, channels)):
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
out_chs = make_div(c * wf)
stride = 1 if stage_idx == 0 else 2
if curr_stride >= output_stride:
@ -376,7 +378,7 @@ class ResNetV2(nn.Module):
stride = 1
stage = ResNetStage(
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn)
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
prev_chs = out_chs
curr_stride *= stride
feat_name = f'stages.{stage_idx}'

Loading…
Cancel
Save