|
|
|
@ -2,10 +2,9 @@ import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .helpers import load_pretrained
|
|
|
|
|
from .layers import SelectAdaptivePool2d
|
|
|
|
|
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
__all__ = ['NASNetALarge']
|
|
|
|
|
|
|
|
|
@ -187,17 +186,17 @@ class CellStem1(nn.Module):
|
|
|
|
|
self.stem_size = stem_size
|
|
|
|
|
self.conv_1x1 = nn.Sequential()
|
|
|
|
|
self.conv_1x1.add_module('relu', nn.ReLU())
|
|
|
|
|
self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_channels, self.num_channels, 1, stride=1, bias=False))
|
|
|
|
|
self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False))
|
|
|
|
|
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
|
|
|
|
|
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.path_1 = nn.Sequential()
|
|
|
|
|
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
|
|
|
|
self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False))
|
|
|
|
|
self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
|
|
|
|
|
self.path_2 = nn.ModuleList()
|
|
|
|
|
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
|
|
|
|
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
|
|
|
|
self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels//2, 1, stride=1, bias=False))
|
|
|
|
|
self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
|
|
|
|
|
|
|
|
|
|
self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)
|
|
|
|
|
|
|
|
|
@ -507,50 +506,50 @@ class NASNetALarge(nn.Module):
|
|
|
|
|
self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2))
|
|
|
|
|
self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier)
|
|
|
|
|
|
|
|
|
|
self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels//2,
|
|
|
|
|
in_channels_right=2*channels, out_channels_right=channels)
|
|
|
|
|
self.cell_1 = NormalCell(in_channels_left=2*channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6*channels, out_channels_right=channels)
|
|
|
|
|
self.cell_2 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6*channels, out_channels_right=channels)
|
|
|
|
|
self.cell_3 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6*channels, out_channels_right=channels)
|
|
|
|
|
self.cell_4 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6*channels, out_channels_right=channels)
|
|
|
|
|
self.cell_5 = NormalCell(in_channels_left=6*channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6*channels, out_channels_right=channels)
|
|
|
|
|
|
|
|
|
|
self.reduction_cell_0 = ReductionCell0(in_channels_left=6*channels, out_channels_left=2*channels,
|
|
|
|
|
in_channels_right=6*channels, out_channels_right=2*channels)
|
|
|
|
|
|
|
|
|
|
self.cell_6 = FirstCell(in_channels_left=6*channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=8*channels, out_channels_right=2*channels)
|
|
|
|
|
self.cell_7 = NormalCell(in_channels_left=8*channels, out_channels_left=2*channels,
|
|
|
|
|
in_channels_right=12*channels, out_channels_right=2*channels)
|
|
|
|
|
self.cell_8 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
|
|
|
|
|
in_channels_right=12*channels, out_channels_right=2*channels)
|
|
|
|
|
self.cell_9 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
|
|
|
|
|
in_channels_right=12*channels, out_channels_right=2*channels)
|
|
|
|
|
self.cell_10 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
|
|
|
|
|
in_channels_right=12*channels, out_channels_right=2*channels)
|
|
|
|
|
self.cell_11 = NormalCell(in_channels_left=12*channels, out_channels_left=2*channels,
|
|
|
|
|
in_channels_right=12*channels, out_channels_right=2*channels)
|
|
|
|
|
|
|
|
|
|
self.reduction_cell_1 = ReductionCell1(in_channels_left=12*channels, out_channels_left=4*channels,
|
|
|
|
|
in_channels_right=12*channels, out_channels_right=4*channels)
|
|
|
|
|
|
|
|
|
|
self.cell_12 = FirstCell(in_channels_left=12*channels, out_channels_left=2*channels,
|
|
|
|
|
in_channels_right=16*channels, out_channels_right=4*channels)
|
|
|
|
|
self.cell_13 = NormalCell(in_channels_left=16*channels, out_channels_left=4*channels,
|
|
|
|
|
in_channels_right=24*channels, out_channels_right=4*channels)
|
|
|
|
|
self.cell_14 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
|
|
|
|
|
in_channels_right=24*channels, out_channels_right=4*channels)
|
|
|
|
|
self.cell_15 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
|
|
|
|
|
in_channels_right=24*channels, out_channels_right=4*channels)
|
|
|
|
|
self.cell_16 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
|
|
|
|
|
in_channels_right=24*channels, out_channels_right=4*channels)
|
|
|
|
|
self.cell_17 = NormalCell(in_channels_left=24*channels, out_channels_left=4*channels,
|
|
|
|
|
in_channels_right=24*channels, out_channels_right=4*channels)
|
|
|
|
|
self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2,
|
|
|
|
|
in_channels_right=2 * channels, out_channels_right=channels)
|
|
|
|
|
self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6 * channels, out_channels_right=channels)
|
|
|
|
|
self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6 * channels, out_channels_right=channels)
|
|
|
|
|
self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6 * channels, out_channels_right=channels)
|
|
|
|
|
self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6 * channels, out_channels_right=channels)
|
|
|
|
|
self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=6 * channels, out_channels_right=channels)
|
|
|
|
|
|
|
|
|
|
self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels,
|
|
|
|
|
in_channels_right=6 * channels, out_channels_right=2 * channels)
|
|
|
|
|
|
|
|
|
|
self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels,
|
|
|
|
|
in_channels_right=8 * channels, out_channels_right=2 * channels)
|
|
|
|
|
self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels,
|
|
|
|
|
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
|
|
|
|
self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
|
|
|
|
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
|
|
|
|
self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
|
|
|
|
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
|
|
|
|
self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
|
|
|
|
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
|
|
|
|
self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
|
|
|
|
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
|
|
|
|
|
|
|
|
|
self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels,
|
|
|
|
|
in_channels_right=12 * channels, out_channels_right=4 * channels)
|
|
|
|
|
|
|
|
|
|
self.cell_12 = FirstCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
|
|
|
|
in_channels_right=16 * channels, out_channels_right=4 * channels)
|
|
|
|
|
self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels,
|
|
|
|
|
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
|
|
|
self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
|
|
|
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
|
|
|
self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
|
|
|
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
|
|
|
self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
|
|
|
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
|
|
|
self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
|
|
|
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
|
|
|
|
|
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
@ -562,9 +561,11 @@ class NASNetALarge(nn.Module):
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
del self.last_linear
|
|
|
|
|
self.last_linear = nn.Linear(
|
|
|
|
|
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
|
|
|
|
|
if num_classes:
|
|
|
|
|
num_features = self.num_features * self.global_pool.feat_mult()
|
|
|
|
|
self.last_linear = nn.Linear(num_features, num_classes)
|
|
|
|
|
else:
|
|
|
|
|
self.last_linear = nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x_conv0 = self.conv0(x)
|
|
|
|
|