diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index da019075..285863f5 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -193,11 +193,12 @@ class Mixed_7a(nn.Module): class Block8(nn.Module): - def __init__(self, scale=1.0, noReLU=False): + __constants__ = ['relu'] # for pre 1.4 torchscript compat + + def __init__(self, scale=1.0, no_relu=False): super(Block8, self).__init__() self.scale = scale - self.noReLU = noReLU self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) @@ -208,8 +209,7 @@ class Block8(nn.Module): ) self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) - if not self.noReLU: - self.relu = nn.ReLU(inplace=False) + self.relu = None if no_relu else nn.ReLU(inplace=False) def forward(self, x): x0 = self.branch0(x) @@ -217,7 +217,7 @@ class Block8(nn.Module): out = torch.cat((x0, x1), 1) out = self.conv2d(out) out = out * self.scale + x - if not self.noReLU: + if self.relu is not None: out = self.relu(out) return out @@ -284,7 +284,7 @@ class InceptionResnetV2(nn.Module): Block8(scale=0.20), Block8(scale=0.20) ) - self.block8 = Block8(noReLU=True) + self.block8 = Block8(no_relu=True) self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC