diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 06b1d178..1a864f87 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -5,6 +5,7 @@ from .resnet import * from .dpn import * from .senet import * from .xception import * +from .nasnet import * from .pnasnet import * from .gen_efficientnet import * from .inception_v3 import * diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py new file mode 100644 index 00000000..9caee809 --- /dev/null +++ b/timm/models/nasnet.py @@ -0,0 +1,612 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .registry import register_model +from .helpers import load_pretrained +from .adaptive_avgmax_pool import SelectAdaptivePool2d + + +__all__ = ['NASNetALarge'] + +default_cfgs = { + 'nasnetalarge': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.875, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1001, + 'first_conv': 'conv_0.conv', + 'classifier': 'last_linear', + }, +} + + +class MaxPoolPad(nn.Module): + + def __init__(self): + super(MaxPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:] + return x + + +class AvgPoolPad(nn.Module): + + def __init__(self, stride=2, padding=1): + super(AvgPoolPad, self).__init__() + self.pad = nn.ZeroPad2d((1, 0, 1, 0)) + self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False) + + def forward(self, x): + x = self.pad(x) + x = self.pool(x) + x = x[:, :, 1:, 1:] + return x + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = nn.Conv2d( + in_channels, in_channels, dw_kernel, + stride=dw_stride, padding=dw_padding, + bias=bias, groups=in_channels) + self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + super(BranchSeparables, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias) + self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + + def forward(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesStem(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False): + super(BranchSeparablesStem, self).__init__() + self.relu = nn.ReLU() + self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) + self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + self.relu1 = nn.ReLU() + self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True) + + def forward(self, x): + x = self.relu(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class BranchSeparablesReduction(BranchSeparables): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False): + BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias) + self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0)) + + def forward(self, x): + x = self.relu(x) + x = self.padding(x) + x = self.separable_1(x) + x = x[:, :, 1:, 1:].contiguous() + x = self.bn_sep_1(x) + x = self.relu1(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Module): + def __init__(self, stem_size, num_channels=42): + super(CellStem0, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2) + self.comb_iter_0_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_right = BranchSeparablesStem(self.stem_size, self.num_channels, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem1(nn.Module): + + def __init__(self, stem_size, num_channels): + super(CellStem1, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_channels, self.num_channels, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False)) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + + x_relu = self.relu(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class FirstCell(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(FirstCell, self).__init__() + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.relu = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.path_2 = nn.ModuleList() + self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + def forward(self, x, x_prev): + x_relu = self.relu(x_prev) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2.pad(x_relu) + x_path2 = x_path2[:, :, 1:, 1:] + x_path2 = self.path_2.avgpool(x_path2) + x_path2 = self.path_2.conv(x_path2) + # final path + x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NormalCell(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + + self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False) + self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell0(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_1_left = MaxPoolPad() + self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_2_left = AvgPoolPad() + self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_4_right = MaxPoolPad() + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell1(nn.Module): + + def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = nn.Sequential() + self.conv_prev_1x1.add_module('relu', nn.ReLU()) + self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False)) + self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True)) + + self.conv_1x1 = nn.Sequential() + self.conv_1x1.add_module('relu', nn.ReLU()) + self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False)) + self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True)) + + self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1) + self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False) + + self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) + self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False) + + self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False) + + self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False) + self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NASNetALarge(nn.Module): + """NASNetALarge (6 @ 4032) """ + + def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2, + drop_rate=0., global_pool='avg'): + super(NASNetALarge, self).__init__() + self.num_classes = num_classes + self.stem_size = stem_size + self.num_features = num_features + self.channel_multiplier = channel_multiplier + self.drop_rate = drop_rate + + channels = self.num_features // 24 + # 24 is default value for the architecture + + self.conv0 = nn.Sequential() + self.conv0.add_module('conv', nn.Conv2d( + in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, bias=False)) + self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_size, eps=0.001, momentum=0.1, affine=True)) + + self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2)) + self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier) + + self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels//2, + in_channels_right=2*channels, out_channels_right=channels) + self.cell_1 = NormalCell(in_channels_left=2*channels, out_channels_left=channels, + in_channels_right=6*channels, out_channels_right=channels) + self.cell_2 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, + in_channels_right=6*channels, out_channels_right=channels) + self.cell_3 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, + in_channels_right=6*channels, out_channels_right=channels) + self.cell_4 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, + in_channels_right=6*channels, out_channels_right=channels) + self.cell_5 = NormalCell(in_channels_left=6*channels, out_channels_left=channels, + in_channels_right=6*channels, out_channels_right=channels) + + self.reduction_cell_0 = ReductionCell0(in_channels_left=6*channels, out_channels_left=2*channels, + in_channels_right=6*channels, out_channels_right=2*channels) + + self.cell_6 = FirstCell(in_channels_left=6*channels, out_channels_left=channels, + in_channels_right=8*channels, out_channels_right=2*channels) + self.cell_7 = NormalCell(in_channels_left=8*channels, out_channels_left=2*channels, + in_channels_right=12*channels, out_channels_right=2*channels) + self.cell_8 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, + in_channels_right=12*channels, out_channels_right=2*channels) + self.cell_9 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, + in_channels_right=12*channels, out_channels_right=2*channels) + self.cell_10 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, + in_channels_right=12*channels, out_channels_right=2*channels) + self.cell_11 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels, + in_channels_right=12*channels, out_channels_right=2*channels) + + self.reduction_cell_1 = ReductionCell1(in_channels_left=12*channels, out_channels_left=4*channels, + in_channels_right=12*channels, out_channels_right=4*channels) + + self.cell_12 = FirstCell(in_channels_left=12*channels, out_channels_left=2*channels, + in_channels_right=16*channels, out_channels_right=4*channels) + self.cell_13 = NormalCell(in_channels_left=16*channels, out_channels_left=4*channels, + in_channels_right=24*channels, out_channels_right=4*channels) + self.cell_14 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, + in_channels_right=24*channels, out_channels_right=4*channels) + self.cell_15 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, + in_channels_right=24*channels, out_channels_right=4*channels) + self.cell_16 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, + in_channels_right=24*channels, out_channels_right=4*channels) + self.cell_17 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels, + in_channels_right=24*channels, out_channels_right=4*channels) + + self.relu = nn.ReLU() + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) + + def forward_features(self, input, pool=True): + x_conv0 = self.conv0(input) + x_stem_0 = self.cell_stem_0(x_conv0) + x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + + x_cell_0 = self.cell_0(x_stem_1, x_stem_0) + x_cell_1 = self.cell_1(x_cell_0, x_stem_1) + x_cell_2 = self.cell_2(x_cell_1, x_cell_0) + x_cell_3 = self.cell_3(x_cell_2, x_cell_1) + x_cell_4 = self.cell_4(x_cell_3, x_cell_2) + x_cell_5 = self.cell_5(x_cell_4, x_cell_3) + + x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) + + x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) + x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) + x_cell_8 = self.cell_8(x_cell_7, x_cell_6) + x_cell_9 = self.cell_9(x_cell_8, x_cell_7) + x_cell_10 = self.cell_10(x_cell_9, x_cell_8) + x_cell_11 = self.cell_11(x_cell_10, x_cell_9) + + x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) + + x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) + x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) + x_cell_14 = self.cell_14(x_cell_13, x_cell_12) + x_cell_15 = self.cell_15(x_cell_14, x_cell_13) + x_cell_16 = self.cell_16(x_cell_15, x_cell_14) + x_cell_17 = self.cell_17(x_cell_16, x_cell_15) + x = self.relu(x_cell_17) + if pool: + x = self.global_pool(x) + x = x.view(x.size(0), -1) + return x + + def forward(self, input): + x = self.forward_features(input) + if self.drop_rate > 0: + x = F.dropout(x, self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +@register_model +def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """NASNet-A large model architecture. + """ + default_cfg = default_cfgs['nasnetalarge'] + model = NASNetALarge(num_classes=1000, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + + return model