|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from .helpers import load_pretrained
|
|
|
|
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
# Names exported by ``from <this module> import *``: only the model class.
__all__ = ['NASNetALarge']
|
|
|
|
|
|
|
|
# Per-model metadata, keyed by the registered model name. The pretrained checkpoint referenced
# by ``url`` has 1001 outputs (``num_classes``), and ``first_conv`` / ``classifier`` name the
# submodules of ``NASNetALarge`` that hold the stem conv and the final linear layer.
default_cfgs = {
    'nasnetalarge': dict(
        url='http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
        input_size=(3, 331, 331),
        pool_size=(11, 11),
        crop_pct=0.911,
        interpolation='bicubic',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_classes=1001,
        first_conv='conv0.conv',
        classifier='last_linear',
    ),
}
|
|
|
|
|
|
|
|
|
|
|
|
class ActConvBn(nn.Module):
    """Pre-activation unit: ReLU, then convolution, then BatchNorm.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution and normalized by the BN.
        kernel_size: convolution kernel size, passed to ``create_conv2d``.
        stride: convolution stride.
        padding: padding spec passed through to ``create_conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
        super().__init__()
        # Out-of-place ReLU: the caller's input tensor is left unmodified (CellStem0, for
        # example, feeds the same tensor to other branches after this module).
        self.act = nn.ReLU()
        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)

    def forward(self, x):
        activated = self.act(x)
        return self.bn(self.conv(activated))
|
|
|
|
|
|
|
|
|
|
|
|
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution, with no normalization and no activation.

    A per-channel spatial conv (``groups=in_channels``, carrying the stride and padding) is
    followed by a 1x1 pointwise conv that mixes channels.

    Args:
        in_channels: channels of the input, kept by the depthwise conv.
        out_channels: channels produced by the pointwise conv.
        kernel_size: spatial kernel size of the depthwise conv.
        stride: stride of the depthwise conv.
        padding: padding spec for the depthwise conv, passed to ``create_conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
        super().__init__()
        self.depthwise_conv2d = create_conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels)
        self.pointwise_conv2d = create_conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        spatial = self.depthwise_conv2d(x)
        return self.pointwise_conv2d(spatial)
|
|
|
|
|
|
|
|
|
|
|
|
class BranchSeparables(nn.Module):
    """Two stacked ReLU -> SeparableConv2d -> BatchNorm units.

    Only the first separable conv carries ``stride``; the second always uses stride 1.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: kernel size of both separable convs.
        stride: stride of the first separable conv.
        pad_type: padding spec passed to both separable convs.
        stem_cell: if True, the first separable conv already widens to ``out_channels``;
            otherwise it keeps ``in_channels`` and the second conv changes the width.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False):
        super().__init__()
        if stem_cell:
            mid_chs = out_channels
        else:
            mid_chs = in_channels
        # First ReLU is out-of-place: the input tensor belongs to the caller.
        self.act_1 = nn.ReLU()
        self.separable_1 = SeparableConv2d(
            in_channels, mid_chs, kernel_size, stride=stride, padding=pad_type)
        self.bn_sep_1 = nn.BatchNorm2d(mid_chs, eps=0.001, momentum=0.1)
        # Second ReLU may be in-place: it only ever sees bn_sep_1's freshly produced output.
        self.act_2 = nn.ReLU(inplace=True)
        self.separable_2 = SeparableConv2d(
            mid_chs, out_channels, kernel_size, stride=1, padding=pad_type)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)

    def forward(self, x):
        hidden = self.bn_sep_1(self.separable_1(self.act_1(x)))
        return self.bn_sep_2(self.separable_2(self.act_2(hidden)))
|
|
|
|
|
|
|
|
|
|
|
|
class CellStem0(nn.Module):
    """First stem cell of NASNet-A Large: a reduction-style cell applied to the ``conv0`` output.

    ``forward(x)`` expects ``x`` with ``stem_size`` channels. It computes a 1x1 projection ``x1``
    with ``num_channels`` channels, then five pairwise-summed "combination iterations". The
    ``BranchSeparables`` built with stride 2 downsample. The output concatenates iterations
    1-4 along dim 1, giving ``4 * num_channels`` channels.

    Args:
        stem_size: channels of the input feature map.
        num_channels: width of every branch inside the cell.
        pad_type: padding spec forwarded to ``create_pool2d`` and ``BranchSeparables``.

    NOTE: submodule attribute names become ``state_dict`` keys (presumably matched against the
    pretrained checkpoint). Construction order fixes the order in which parameter
    initialization consumes the RNG. Keep both stable.
    """
    def __init__(self, stem_size, num_channels=42, pad_type=''):
        super(CellStem0, self).__init__()
        self.num_channels = num_channels
        self.stem_size = stem_size
        # 1x1 projection of the input to the cell width (``x1`` in forward).
        self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1)

        # The stem_cell=True branches read the raw input (stem_size channels) and widen it to
        # num_channels in their first separable conv.
        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True)

        # Iterations 3 and 4 partly consume the iteration-0 sum rather than the cell inputs.
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x):
        x1 = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x1)
        x_comb_iter_0_right = self.comb_iter_0_right(x)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x1)
        x_comb_iter_1_right = self.comb_iter_1_right(x)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x1)
        x_comb_iter_2_right = self.comb_iter_2_right(x)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x1)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        # x_comb_iter_0 is not concatenated; it only feeds iterations 3 and 4.
        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
|
|
|
|
|
|
|
|
|
|
|
|
class CellStem1(nn.Module):
    """Second stem cell of NASNet-A Large, fed by both the ``conv0`` output and ``CellStem0``.

    ``forward(x_conv0, x_stem_0)`` builds two cell inputs:

    * ``x_left``: ``x_stem_0`` projected by a 1x1 ``ActConvBn`` from ``2 * num_channels`` to
      ``num_channels`` channels.
    * ``x_right``: a factorized stride-2 reduction of ``x_conv0``. Two 1x1-subsampled paths,
      offset by one pixel, are each projected to ``num_channels // 2`` channels, then
      concatenated and batch-normalized.

    Five pairwise-summed "combination iterations" follow. The output concatenates iterations
    1-4 along dim 1, giving ``4 * num_channels`` channels.

    Args:
        stem_size: channels of ``x_conv0``.
        num_channels: width of every branch inside the cell. It must be even so that the two
            ``num_channels // 2`` paths concatenate back to ``num_channels`` for
            ``final_path_bn``.
        pad_type: padding spec forwarded to ``create_pool2d`` and ``BranchSeparables``.

    NOTE: submodule attribute names become ``state_dict`` keys and construction order fixes
    parameter-init RNG order. Keep both stable.
    """

    def __init__(self, stem_size, num_channels, pad_type=''):
        super(CellStem1, self).__init__()
        self.num_channels = num_channels
        self.stem_size = stem_size
        self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1)

        # Out-of-place ReLU, so the caller's x_conv0 is not modified.
        self.act = nn.ReLU()
        # Path 1: a 1x1 AvgPool2d with stride 2 keeps every other pixel, starting at (0, 0).
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))

        # Path 2: ZeroPad2d((left=-1, right=1, top=-1, bottom=1)) crops the first column/row and
        # appends a zero column/row, so the same stride-2 subsampling starts at (1, 1). That
        # samples the pixels path 1 skips.
        self.path_2 = nn.Sequential()
        self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1)

        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)

        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x_conv0, x_stem_0):
        x_left = self.conv_1x1(x_stem_0)

        x_relu = self.act(x_conv0)
        # path 1
        x_path1 = self.path_1(x_relu)
        # path 2
        x_path2 = self.path_2(x_relu)
        # final path: concatenate the two offset samplings along dim 1, then batch-normalize
        x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))

        x_comb_iter_0_left = self.comb_iter_0_left(x_left)
        x_comb_iter_0_right = self.comb_iter_0_right(x_right)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_right)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_left)
        x_comb_iter_2_right = self.comb_iter_2_right(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_left)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        # x_comb_iter_0 is not concatenated; it only feeds iterations 3 and 4.
        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
|
|
|
|
|
|
|
|
|
|
|
|
class FirstCell(nn.Module):
    """NASNet-A normal cell whose previous-layer input goes through a factorized stride-2 reduction.

    In ``NASNetALarge`` this is the first cell of each stage (``cell_0``, ``cell_6``,
    ``cell_12``). Its ``x_prev`` comes from before the preceding stem/reduction cell, presumably
    at twice the spatial size of ``x``.

    ``forward(x, x_prev)`` builds two cell inputs:

    * ``x_left``: ``x_prev`` reduced by two 1x1-subsampled paths offset by one pixel. Each path is
      projected from ``in_chs_left`` to ``out_chs_left`` channels, then the paths are
      concatenated and batch-normalized, giving ``2 * out_chs_left`` channels.
    * ``x_right``: ``x`` projected by a 1x1 ``ActConvBn`` from ``in_chs_right`` to
      ``out_chs_right`` channels.

    Several iterations add ``x_left``-derived and ``x_right``-derived tensors, so
    ``2 * out_chs_left`` must equal ``out_chs_right``. The output concatenates ``x_left`` with
    the five iteration sums along dim 1, giving ``2 * out_chs_left + 5 * out_chs_right``
    channels.

    NOTE: submodule attribute names become ``state_dict`` keys and construction order fixes
    parameter-init RNG order. Keep both stable.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(FirstCell, self).__init__()
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1)

        # Out-of-place ReLU, so the caller's x_prev is not modified.
        self.act = nn.ReLU()
        # Path 1: a 1x1 AvgPool2d with stride 2 keeps every other pixel, starting at (0, 0).
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))

        # Path 2: ZeroPad2d((left=-1, right=1, top=-1, bottom=1)) shifts the image by one pixel,
        # so the same stride-2 subsampling starts at (1, 1).
        self.path_2 = nn.Sequential()
        self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1)

        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

        self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

        # Iteration 2 pairs this pool (on x_right) with x_left unchanged.
        self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        # Two identically configured, parameter-free pools, both applied to x_left.
        self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        # Iteration 4 pairs this branch (on x_right) with x_right unchanged (a residual add).
        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

    def forward(self, x, x_prev):
        x_relu = self.act(x_prev)
        x_path1 = self.path_1(x_relu)
        x_path2 = self.path_2(x_relu)
        x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        # Unlike the reduction cells, the (reduced) left input is part of the output.
        x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
|
|
|
|
|
|
|
|
|
|
|
|
class NormalCell(nn.Module):
    """NASNet-A normal cell: both inputs are projected with 1x1 ``ActConvBn`` units (stride 1).

    ``forward(x, x_prev)`` projects ``x_prev`` to ``out_chs_left`` channels (``x_left``) and ``x``
    to ``out_chs_right`` channels (``x_right``). Both projections use stride 1, so ``x`` and
    ``x_prev`` must already have the same spatial size. Iteration 0 adds an ``x_right`` branch
    to an ``x_left`` branch, so ``out_chs_left`` must equal ``out_chs_right``. The output
    concatenates ``x_left`` with the five iteration sums along dim 1, giving
    ``out_chs_left + 5 * out_chs_right`` channels.

    Args:
        in_chs_left: channels of ``x_prev``.
        out_chs_left: projected width of ``x_prev``.
        in_chs_right: channels of ``x``.
        out_chs_right: projected width of ``x`` and of every branch.
        pad_type: padding spec forwarded to the projections, pools and ``BranchSeparables``.

    NOTE: submodule attribute names become ``state_dict`` keys and construction order fixes
    parameter-init RNG order. Keep both stable.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(NormalCell, self).__init__()
        self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)

        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)

        self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)

        # Iteration 2 pairs this pool (on x_right) with x_left unchanged.
        self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        # Two identically configured, parameter-free pools, both applied to x_left.
        self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        # Iteration 4 pairs this branch (on x_right) with x_right unchanged (a residual add).
        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
|
|
|
|
|
|
|
|
|
|
|
|
class ReductionCell0(nn.Module):
    """NASNet-A reduction cell. In ``NASNetALarge`` it sits between the first and second stages.

    ``forward(x, x_prev)`` projects ``x_prev`` to ``out_chs_left`` channels (``x_left``) and ``x``
    to ``out_chs_right`` channels (``x_right``) with stride-1 1x1 ``ActConvBn`` units.
    ``x_left`` feeds ``BranchSeparables(out_chs_right, ...)``, so ``out_chs_left`` must equal
    ``out_chs_right``. The ``BranchSeparables`` built with stride 2 downsample. The output
    concatenates iterations 1-4 along dim 1, giving ``4 * out_chs_right`` channels.

    Args:
        in_chs_left: channels of ``x_prev``.
        out_chs_left: projected width of ``x_prev``.
        in_chs_right: channels of ``x``.
        out_chs_right: projected width of ``x`` and of every branch.
        pad_type: padding spec forwarded to the projections, pools and ``BranchSeparables``.

    NOTE: submodule attribute names become ``state_dict`` keys and construction order fixes
    parameter-init RNG order. Keep both stable.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(ReductionCell0, self).__init__()
        self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)

        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)

        # Iterations 3 and 4 partly consume the (already downsampled) iteration-0 sum.
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        # x_comb_iter_0 is not concatenated; it only feeds iterations 3 and 4.
        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
|
|
|
|
|
|
|
|
|
|
|
|
class ReductionCell1(nn.Module):
    """NASNet-A reduction cell. In ``NASNetALarge`` it sits between the second and third stages.

    Its structure matches ``ReductionCell0`` in this file; it is kept as a separate class, and
    its attribute names define its ``state_dict`` keys.

    ``forward(x, x_prev)`` projects ``x_prev`` to ``out_chs_left`` channels (``x_left``) and ``x``
    to ``out_chs_right`` channels (``x_right``) with stride-1 1x1 ``ActConvBn`` units.
    ``x_left`` feeds ``BranchSeparables(out_chs_right, ...)``, so ``out_chs_left`` must equal
    ``out_chs_right``. The ``BranchSeparables`` built with stride 2 downsample. The output
    concatenates iterations 1-4 along dim 1, giving ``4 * out_chs_right`` channels.

    Args:
        in_chs_left: channels of ``x_prev``.
        out_chs_left: projected width of ``x_prev``.
        in_chs_right: channels of ``x``.
        out_chs_right: projected width of ``x`` and of every branch.
        pad_type: padding spec forwarded to the projections, pools and ``BranchSeparables``.

    NOTE: construction order fixes parameter-init RNG order. Keep it stable.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(ReductionCell1, self).__init__()
        self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)

        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)

        # Iterations 3 and 4 partly consume the (already downsampled) iteration-0 sum.
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        # x_comb_iter_0 is not concatenated; it only feeds iterations 3 and 4.
        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
|
|
|
|
|
|
|
|
|
|
|
|
class NASNetALarge(nn.Module):
    """NASNet-A Large (6 @ 4032): a stem plus three stages of six cells, 4032 output features.

    Layout, as built in ``__init__``:

    * stem: ``conv0`` (``ConvBnAct`` with a 3x3 kernel, stride 2, ``act_layer=None``), then
      ``cell_stem_0`` and ``cell_stem_1``;
    * stage 1: ``cell_0`` (FirstCell) and ``cell_1``..``cell_5`` (NormalCell);
    * stage 2: ``reduction_cell_0``, then ``cell_6`` (FirstCell) and ``cell_7``..``cell_11``;
    * stage 3: ``reduction_cell_1``, then ``cell_12`` (FirstCell) and ``cell_13``..``cell_17``;
    * head: ReLU, ``global_pool``, optional dropout, ``last_linear``.

    Args:
        num_classes: number of classifier outputs.
        in_chans: number of input image channels. Defaults to 3, matching the RGB
            ``input_size``/``mean``/``std`` in ``default_cfgs['nasnetalarge']`` and the default
            of the ``nasnetalarge`` factory.
        stem_size: output channels of ``conv0``.
        channel_multiplier: divisor for the stem-cell widths (``channels // multiplier**2`` and
            ``channels // multiplier``). The cell wiring only lines up for the default of 2:
            ``CellStem0`` emits ``4 * (channels // m**2)`` channels, but ``cell_0`` and
            ``cell_stem_1`` expect ``channels``.
        num_features: channels of the final feature map. The base width is
            ``channels = num_features // 24`` and the last cell emits ``24 * channels``, so this
            should be a multiple of 24. The stem cells also need ``channels`` divisible by 4.
        output_stride: only 32 is supported.
        drop_rate: dropout probability before the classifier; dropout is skipped when it is 0.
        global_pool: pooling type passed to ``SelectAdaptivePool2d``.
        pad_type: padding spec forwarded to every cell.
    """

    def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
                 num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
        super(NASNetALarge, self).__init__()
        self.num_classes = num_classes
        self.stem_size = stem_size
        self.num_features = num_features
        self.channel_multiplier = channel_multiplier
        self.drop_rate = drop_rate
        assert output_stride == 32

        # 24 is default value for the architecture: the final NormalCell emits 24 * channels.
        channels = self.num_features // 24

        self.conv0 = ConvBnAct(
            in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
            norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)

        self.cell_stem_0 = CellStem0(
            self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)
        self.cell_stem_1 = CellStem1(
            self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type)

        # Stage 1: cells emit 6 * channels.
        self.cell_0 = FirstCell(
            in_chs_left=channels, out_chs_left=channels // 2,
            in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_1 = NormalCell(
            in_chs_left=2 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_2 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_3 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_4 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_5 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)

        # Stage 2: the reduction cell emits 8 * channels; the following cells emit 12 * channels.
        self.reduction_cell_0 = ReductionCell0(
            in_chs_left=6 * channels, out_chs_left=2 * channels,
            in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_6 = FirstCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_7 = NormalCell(
            in_chs_left=8 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_8 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_9 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_10 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_11 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)

        # Stage 3: the reduction cell emits 16 * channels; the following cells emit 24 * channels.
        self.reduction_cell_1 = ReductionCell1(
            in_chs_left=12 * channels, out_chs_left=4 * channels,
            in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_12 = FirstCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_13 = NormalCell(
            in_chs_left=16 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_14 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_15 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_16 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_17 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)

        # In-place is safe here: x_cell_17 is not used after the final activation.
        self.act = nn.ReLU(inplace=True)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

    def get_classifier(self):
        """Return the final classification layer."""
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling and classifier head.

        Args:
            num_classes: new number of outputs; a falsy value (e.g. 0) installs ``nn.Identity``
                so that ``forward`` returns pooled features.
            global_pool: pooling type for the new ``SelectAdaptivePool2d``.
        """
        self.num_classes = num_classes
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        if num_classes:
            num_features = self.num_features * self.global_pool.feat_mult()
            self.last_linear = nn.Linear(num_features, num_classes)
        else:
            self.last_linear = nn.Identity()

    def forward_features(self, x):
        """Run the stem and all cells; return the ReLU-activated final feature map."""
        # Stem. Both stem cells read x_conv0; the second also reads the first's output.
        x_conv0 = self.conv0(x)
        x_stem_0 = self.cell_stem_0(x_conv0)
        x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)

        # Stage 1: each cell receives (current, previous) feature maps.
        x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
        x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
        x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
        x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
        x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
        x_cell_5 = self.cell_5(x_cell_4, x_cell_3)

        # Stage 2. cell_6 takes x_cell_4 (pre-reduction) as its "previous" input and reduces it.
        x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
        x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
        x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
        x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
        x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
        x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
        x_cell_11 = self.cell_11(x_cell_10, x_cell_9)

        # Stage 3, wired the same way as stage 2.
        x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
        x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
        x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
        x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
        x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
        x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
        x_cell_17 = self.cell_17(x_cell_16, x_cell_15)

        x = self.act(x_cell_17)
        return x

    def forward(self, x):
        """Compute features, then pool, flatten, optionally drop out, and classify."""
        x = self.forward_features(x)
        x = self.global_pool(x).flatten(1)
        if self.drop_rate > 0:
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.last_linear(x)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Build a NASNet-A Large model.

    Args:
        pretrained: if True, load weights via ``load_pretrained`` using the
            ``default_cfgs['nasnetalarge']`` entry.
        num_classes: number of classifier outputs.
        in_chans: number of input image channels.
        **kwargs: extra keyword arguments forwarded to ``NASNetALarge``.

    Returns:
        The constructed ``NASNetALarge``, with ``default_cfg`` attached.
    """
    cfg = default_cfgs['nasnetalarge']
    model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = cfg
    if not pretrained:
        return model
    load_pretrained(model, cfg, num_classes, in_chans)
    return model
|