diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index d283deb8..2469acd2 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -200,8 +200,8 @@ class TResNet(nn.Module): self.feature_info = [ dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D? - dict(num_chs=self.planes, reduction=4, module='body.layer1'), - dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'), + dict(num_chs=self.planes * (Bottleneck.expansion if v2 else 1), reduction=4, module='body.layer1'), + dict(num_chs=self.planes * 2 * (Bottleneck.expansion if v2 else 1), reduction=8, module='body.layer2'), dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'), dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'), ]