From d55bcc0fee3adac6a814beef5be3e871902b4b27 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 16 Jan 2021 16:32:03 -0800 Subject: [PATCH] Finishing adding stochastic depth support to BiT ResNetV2 models --- timm/models/resnetv2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index f51d6357..1acc5eb0 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -249,6 +249,7 @@ class Bottleneck(nn.Module): x = self.norm2(x) x = self.conv3(x) x = self.norm3(x) + x = self.drop_path(x) x = self.act3(x + shortcut) return x @@ -366,9 +367,10 @@ class ResNetV2(nn.Module): prev_chs = stem_chs curr_stride = 4 dilation = 1 + block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] block_fn = PreActBottleneck if preact else Bottleneck self.stages = nn.Sequential() - for stage_idx, (d, c) in enumerate(zip(layers, channels)): + for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)): out_chs = make_div(c * wf) stride = 1 if stage_idx == 0 else 2 if curr_stride >= output_stride: @@ -376,7 +378,7 @@ class ResNetV2(nn.Module): stride = 1 stage = ResNetStage( prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down, - act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn) + act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn) prev_chs = out_chs curr_stride *= stride feat_name = f'stages.{stage_idx}'