@@ -42,10 +42,8 @@ for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd w


 class SeparableConv2d(nn.Module):
-    def __init__(self, inplanes, planes, kernel_size=3, stride=1,
-                 dilation=1, bias=False, norm_layer=None, norm_kwargs=None):
+    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
         super(SeparableConv2d, self).__init__()
-        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         self.kernel_size = kernel_size
         self.dilation = dilation

@@ -54,7 +52,7 @@ class SeparableConv2d(nn.Module):
         self.conv_dw = nn.Conv2d(
             inplanes, inplanes, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=inplanes, bias=bias)
-        self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
+        self.bn = norm_layer(num_features=inplanes)
         # pointwise convolution
         self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)

@@ -66,10 +64,8 @@ class SeparableConv2d(nn.Module):


 class Block(nn.Module):
-    def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
-                 norm_layer=None, norm_kwargs=None, ):
+    def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
         super(Block, self).__init__()
-        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if isinstance(planes, (list, tuple)):
             assert len(planes) == 3
         else:
@@ -80,7 +76,7 @@ class Block(nn.Module):
             self.skip = nn.Sequential()
             self.skip.add_module('conv1', nn.Conv2d(
                 inplanes, outplanes, 1, stride=stride, bias=False)),
-            self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs))
+            self.skip.add_module('bn1', norm_layer(num_features=outplanes))
         else:
             self.skip = None

@@ -88,9 +84,8 @@ class Block(nn.Module):
         for i in range(3):
             rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
             rep['conv%d' % (i + 1)] = SeparableConv2d(
-                inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation,
-                norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-            rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
+                inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
+            rep['bn%d' % (i + 1)] = norm_layer(planes[i])
             inplanes = planes[i]

         if not start_with_relu:
@@ -115,74 +110,63 @@ class Xception65(nn.Module):
     """

     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
-                 norm_kwargs=None, drop_rate=0., global_pool='avg'):
+                 drop_rate=0., global_pool='avg'):
         super(Xception65, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
             exit_block20_stride = 2
-            middle_block_dilation = 1
-            exit_block_dilations = (1, 1)
+            middle_dilation = 1
+            exit_dilation = (1, 1)
         elif output_stride == 16:
             entry_block3_stride = 2
             exit_block20_stride = 1
-            middle_block_dilation = 1
-            exit_block_dilations = (1, 2)
+            middle_dilation = 1
+            exit_dilation = (1, 2)
         elif output_stride == 8:
             entry_block3_stride = 1
             exit_block20_stride = 1
-            middle_block_dilation = 2
-            exit_block_dilations = (2, 4)
+            middle_dilation = 2
+            exit_dilation = (2, 4)
         else:
             raise NotImplementedError

         # Entry flow
         self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
-        self.bn1 = norm_layer(num_features=32, **norm_kwargs)
+        self.bn1 = norm_layer(num_features=32)
         self.act1 = nn.ReLU(inplace=True)

         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn2 = norm_layer(num_features=64)
         self.act2 = nn.ReLU(inplace=True)

-        self.block1 = Block(
-            64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
         self.block1_act = nn.ReLU(inplace=True)
-        self.block2 = Block(
-            128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.block3 = Block(
-            256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
+        self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)

         # Middle flow
         self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
-            728, 728, stride=1, dilation=middle_block_dilation,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))
+            728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))

         # Exit flow
         self.block20 = Block(
-            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
         self.block20_act = nn.ReLU(inplace=True)

-        self.conv3 = SeparableConv2d(
-            1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
+        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+        self.bn3 = norm_layer(num_features=1536)
         self.act3 = nn.ReLU(inplace=True)

-        self.conv4 = SeparableConv2d(
-            1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
+        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+        self.bn4 = norm_layer(num_features=1536)
         self.act4 = nn.ReLU(inplace=True)

         self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+            1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+        self.bn5 = norm_layer(num_features=self.num_features)
         self.act5 = nn.ReLU(inplace=True)
         self.feature_info = [
             dict(num_chs=64, reduction=2, module='act2'),
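
Note: with norm_kwargs removed from every constructor, the remaining hook for customizing normalization is the norm_layer callable itself. A minimal sketch of the calling pattern this implies, assuming a caller previously passed something like norm_kwargs=dict(eps=1e-3) — the eps/momentum values below are hypothetical examples, not taken from this diff:

    from functools import partial

    import torch.nn as nn

    # Bind the extra BatchNorm arguments into the norm_layer callable itself,
    # instead of passing a separate norm_kwargs dict alongside it.
    # (eps/momentum are hypothetical values for illustration.)
    norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

    # Every internal call of the form norm_layer(num_features=...) now picks
    # up eps/momentum automatically.
    model = Xception65(num_classes=1000, output_stride=16, norm_layer=norm_layer)

Since nn.BatchNorm2d takes num_features as its first argument, partial composes cleanly with the norm_layer(num_features=...) calls used throughout this file.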