diff --git a/timm/models/resnet.py b/timm/models/resnet.py index fe10ff22..24d1bf3b 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -601,8 +601,8 @@ class ResNet(nn.Module): else: self.maxpool = nn.Sequential(*[ nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1), - nn.BatchNorm2d(inplanes), - nn.ReLU() + norm_layer(inplanes), + act_layer(inplace=True) ]) # Feature Blocks