InceptionResNetV2 torchscript compatible

pull/82/head
Ross Wightman 5 years ago
parent 19d93fe454
commit f96b3e5e92

@ -193,11 +193,12 @@ class Mixed_7a(nn.Module):
class Block8(nn.Module):
def __init__(self, scale=1.0, noReLU=False):
__constants__ = ['relu'] # for pre 1.4 torchscript compat
def __init__(self, scale=1.0, no_relu=False):
super(Block8, self).__init__()
self.scale = scale
self.noReLU = noReLU
self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
@ -208,8 +209,7 @@ class Block8(nn.Module):
)
self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
if not self.noReLU:
self.relu = nn.ReLU(inplace=False)
self.relu = None if no_relu else nn.ReLU(inplace=False)
def forward(self, x):
x0 = self.branch0(x)
@ -217,7 +217,7 @@ class Block8(nn.Module):
out = torch.cat((x0, x1), 1)
out = self.conv2d(out)
out = out * self.scale + x
if not self.noReLU:
if self.relu is not None:
out = self.relu(out)
return out
@ -284,7 +284,7 @@ class InceptionResnetV2(nn.Module):
Block8(scale=0.20),
Block8(scale=0.20)
)
self.block8 = Block8(noReLU=True)
self.block8 = Block8(no_relu=True)
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC

Loading…
Cancel
Save