@@ -193,11 +193,12 @@ class Mixed_7a(nn.Module):
 class Block8(nn.Module):
+    __constants__ = ['relu']  # for pre 1.4 torchscript compat
 
-    def __init__(self, scale=1.0, noReLU=False):
+    def __init__(self, scale=1.0, no_relu=False):
         super(Block8, self).__init__()
         self.scale = scale
-        self.noReLU = noReLU
 
         self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
@@ -208,8 +209,7 @@ class Block8(nn.Module):
         )
 
         self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
-        if not self.noReLU:
-            self.relu = nn.ReLU(inplace=False)
+        self.relu = None if no_relu else nn.ReLU(inplace=False)
 
     def forward(self, x):
         x0 = self.branch0(x)
@@ -217,7 +217,7 @@ class Block8(nn.Module):
         out = torch.cat((x0, x1), 1)
         out = self.conv2d(out)
         out = out * self.scale + x
-        if not self.noReLU:
+        if self.relu is not None:
             out = self.relu(out)
         return out
 
@@ -284,7 +284,7 @@ class InceptionResnetV2(nn.Module):
             Block8(scale=0.20),
             Block8(scale=0.20)
         )
-        self.block8 = Block8(noReLU=True)
+        self.block8 = Block8(no_relu=True)
         self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
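Note (not part of the patch): the old code only created `self.relu` when `noReLU` was false, so `forward` could reference an attribute that might not exist, which TorchScript generally cannot compile. The change always assigns `self.relu`, either a module or `None`, guards it with an `is not None` check, and keeps the `__constants__` entry for pre-1.4 TorchScript. Below is a minimal, illustrative sketch of the same pattern in isolation; `ToyBlock` and its layer sizes are invented for the example and are not from the diff.

# Illustrative sketch only, not part of the patch. ToyBlock mimics the
# Block8 pattern shown above with made-up layer sizes.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Listing 'relu' in __constants__ mirrors the diff; it is there for
    # pre-1.4 TorchScript compatibility when the attribute holds None.
    __constants__ = ['relu']

    def __init__(self, no_relu=False):
        super(ToyBlock, self).__init__()
        self.conv = nn.Conv2d(8, 8, kernel_size=1, stride=1)
        # Always assign the attribute: either a real module or None.
        self.relu = None if no_relu else nn.ReLU(inplace=False)

    def forward(self, x):
        out = self.conv(x) + x
        # The None check is scriptable; the old pattern of only defining
        # self.relu when a flag was set is not.
        if self.relu is not None:
            out = self.relu(out)
        return out


# Both variants can be scripted and run.
for flag in (False, True):
    scripted = torch.jit.script(ToyBlock(no_relu=flag))
    print(scripted(torch.randn(1, 8, 4, 4)).shape)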