From 0bc50e84f8c12529d781c6c46d3c927f335cf349 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 10 Apr 2019 14:12:28 -0700
Subject: [PATCH] Lots of refactoring and cleanup.

* Move 'test time pool' to Module that can be used by any model, remove from DPN
* Remove ResNext model file and combine with ResNet
* Remove fbresnet200 as it was an old conversion and pretrained performance not worth param count
* Cleanup adaptive avgmax pooling and add back concat variant
* Factor out checkpoint load fn
---
 data/utils.py                  |    7 +-
 inference.py                   |   26 +-
 models/__init__.py             |    3 +-
 models/adaptive_avgmax_pool.py |   66 +-
 models/densenet.py             |    3 +-
 models/dpn.py                  |   43 +-
 models/fbresnet200.py          | 1254 --------------------------------
 models/inception_resnet_v2.py  |    2 +-
 models/inception_v4.py         |    2 +-
 models/model_factory.py        |   62 +-
 models/resnet.py               |  102 ++-
 models/resnext.py              |  175 -----
 models/senet.py                |    4 +-
 models/test_time_pool.py       |   27 +
 validate.py                    |   42 +-
 15 files changed, 244 insertions(+), 1574 deletions(-)
 delete mode 100644 models/fbresnet200.py
 delete mode 100644 models/resnext.py
 create mode 100644 models/test_time_pool.py

diff --git a/data/utils.py b/data/utils.py
index 836df37f..964f4812 100644
--- a/data/utils.py
+++ b/data/utils.py
@@ -39,8 +39,7 @@ class PrefetchLoader:
             with torch.cuda.stream(stream):
                 next_input = next_input.cuda(non_blocking=True)
                 next_target = next_target.cuda(non_blocking=True)
-                next_input = next_input.float()
-                next_input = next_input.sub_(self.mean).div_(self.std)
+                next_input = next_input.float().sub_(self.mean).div_(self.std)
                 if self.random_erasing is not None:
                     next_input = self.random_erasing(next_input)
 
@@ -74,6 +73,7 @@ def create_loader(
         std=IMAGENET_DEFAULT_STD,
         num_workers=1,
         distributed=False,
+        crop_pct=None,
 ):
 
     if is_training:
@@ -87,7 +87,8 @@ def create_loader(
             img_size,
             use_prefetcher=use_prefetcher,
             mean=mean,
-            std=std)
+            std=std,
+            crop_pct=crop_pct)
 
     dataset.transform = transform
 
diff --git a/inference.py b/inference.py
index d6d4d385..8b696090 100644
--- a/inference.py
+++ b/inference.py
@@ -11,7 +11,7 @@ import argparse
 import numpy as np
 import torch
 
-from models import create_model
+from models import create_model, load_checkpoint, TestTimePoolHead
 from data import Dataset, create_loader, get_model_meanstd
 from utils import AverageMeter
 
@@ -49,21 +49,15 @@ def main():
     model = create_model(
         args.model,
         num_classes=num_classes,
-        pretrained=args.pretrained,
-        test_time_pool=args.test_time_pool)
-
-    # resume from a checkpoint
-    if args.checkpoint and os.path.isfile(args.checkpoint):
-        print("=> loading checkpoint '{}'".format(args.checkpoint))
-        checkpoint = torch.load(args.checkpoint)
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            model.load_state_dict(checkpoint['state_dict'])
-        else:
-            model.load_state_dict(checkpoint)
-        print("=> loaded checkpoint '{}'".format(args.checkpoint))
-    elif not args.pretrained:
-        print("=> no checkpoint found at '{}'".format(args.checkpoint))
-        exit(1)
+        pretrained=args.pretrained)
+
+    print('Model %s created, param count: %d' %
+          (args.model, sum([m.numel() for m in model.parameters()])))
+
+    # load a checkpoint
+    if not args.pretrained:
+        if not load_checkpoint(model, args.checkpoint):
+            exit(1)
 
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
diff --git a/models/__init__.py b/models/__init__.py
index b24eceb5..e975d08d 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,2 +1,3 @@
-from .model_factory import create_model
+from .model_factory import create_model, load_checkpoint
+from .test_time_pool import TestTimePoolHead
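Usage note (editor's sketch, not part of the patch): the adaptive_avgmax_pool.py rework below splits the old overloaded pooling layer into explicit avg/max/avgmax/catavgmax variants. 'catavgmax' concatenates instead of averaging, so it doubles the channel count, and SelectAdaptivePool2d.feat_mult() reports that multiplier for sizing the classifier, as the resnet.py changes in this patch do. A minimal illustration, assuming the module layout from this commit:

    import torch
    from models.adaptive_avgmax_pool import (
        adaptive_avgmax_pool2d, adaptive_catavgmax_pool2d, SelectAdaptivePool2d)

    x = torch.randn(2, 512, 7, 7)              # NCHW feature map
    print(adaptive_avgmax_pool2d(x).shape)     # torch.Size([2, 512, 1, 1]), 0.5 * (avg + max)
    print(adaptive_catavgmax_pool2d(x).shape)  # torch.Size([2, 1024, 1, 1]), channels doubled
    pool = SelectAdaptivePool2d(pool_type='catavgmax')
    print(pool.feat_mult())                    # 2 -> scale fc in_features by this
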
diff --git a/models/adaptive_avgmax_pool.py b/models/adaptive_avgmax_pool.py
index 01fcb4ae..2672fb0c 100644
--- a/models/adaptive_avgmax_pool.py
+++ b/models/adaptive_avgmax_pool.py
@@ -14,29 +14,70 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1):
+def adaptive_pool_feat_mult(pool_type='avg'):
+    if pool_type == 'catavgmax':
+        return 2
+    else:
+        return 1
+
+
+def adaptive_avgmax_pool2d(x, output_size=1):
+    x_avg = F.adaptive_avg_pool2d(x, output_size)
+    x_max = F.adaptive_max_pool2d(x, output_size)
+    return 0.5 * (x_avg + x_max)
+
+
+def adaptive_catavgmax_pool2d(x, output_size=1):
+    x_avg = F.adaptive_avg_pool2d(x, output_size)
+    x_max = F.adaptive_max_pool2d(x, output_size)
+    return torch.cat((x_avg, x_max), 1)
+
+
+def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
     """Selectable global pooling function with dynamic input kernel size
     """
-    if pool_type == 'avgmax':
-        x_avg = F.adaptive_avg_pool2d(x, output_size)
-        x_max = F.adaptive_max_pool2d(x, output_size)
-        x = 0.5 * (x_avg + x_max)
+    if pool_type == 'avg':
+        x = F.adaptive_avg_pool2d(x, output_size)
+    elif pool_type == 'avgmax':
+        x = adaptive_avgmax_pool2d(x, output_size)
+    elif pool_type == 'catavgmax':
+        x = adaptive_catavgmax_pool2d(x, output_size)
     elif pool_type == 'max':
         x = F.adaptive_max_pool2d(x, output_size)
     else:
-        x = F.adaptive_avg_pool2d(x, output_size)
+        assert False, 'Invalid pool type: %s' % pool_type
     return x
 
 
 class AdaptiveAvgMaxPool2d(torch.nn.Module):
+    def __init__(self, output_size=1):
+        super(AdaptiveAvgMaxPool2d, self).__init__()
+        self.output_size = output_size
+
+    def forward(self, x):
+        return adaptive_avgmax_pool2d(x, self.output_size)
+
+
+class AdaptiveCatAvgMaxPool2d(torch.nn.Module):
+    def __init__(self, output_size=1):
+        super(AdaptiveCatAvgMaxPool2d, self).__init__()
+        self.output_size = output_size
+
+    def forward(self, x):
+        return adaptive_catavgmax_pool2d(x, self.output_size)
+
+
+class SelectAdaptivePool2d(torch.nn.Module):
     """Selectable global pooling layer with dynamic input kernel size
     """
     def __init__(self, output_size=1, pool_type='avg'):
-        super(AdaptiveAvgMaxPool2d, self).__init__()
+        super(SelectAdaptivePool2d, self).__init__()
         self.output_size = output_size
         self.pool_type = pool_type
         if pool_type == 'avgmax':
-            self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)])
+            self.pool = AdaptiveAvgMaxPool2d(output_size)
+        elif pool_type == 'catavgmax':
+            self.pool = AdaptiveCatAvgMaxPool2d(output_size)
         elif pool_type == 'max':
             self.pool = nn.AdaptiveMaxPool2d(output_size)
         else:
@@ -45,11 +86,10 @@ class AdaptiveAvgMaxPool2d(torch.nn.Module):
             self.pool = nn.AdaptiveAvgPool2d(output_size)
 
     def forward(self, x):
-        if self.pool_type == 'avgmax':
-            x = 0.5 * torch.sum(torch.stack([p(x) for p in self.pool]), 0).squeeze(dim=0)
-        else:
-            x = self.pool(x)
-        return x
+        return self.pool(x)
+
+    def feat_mult(self):
+        return adaptive_pool_feat_mult(self.pool_type)
 
     def __repr__(self):
         return self.__class__.__name__ + ' (' \
diff --git a/models/densenet.py b/models/densenet.py
index 9a63533f..46f8f7b9 100644
--- a/models/densenet.py
+++ b/models/densenet.py
@@ -83,7 +83,6 @@ def densenet161(pretrained=False, **kwargs):
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
-    print(kwargs)
     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12,
36, 24), **kwargs) if pretrained: state_dict = model_zoo.load_url(model_urls['densenet161']) @@ -193,7 +192,7 @@ class DenseNet(nn.Module): x = self.features(x) x = F.relu(x, inplace=True) if pool: - x = adaptive_avgmax_pool2d(x, self.global_pool) + x = select_adaptive_pool2d(x, self.global_pool) x = x.view(x.size(0), -1) return x diff --git a/models/dpn.py b/models/dpn.py index e8f84da2..ec8fe9d2 100644 --- a/models/dpn.py +++ b/models/dpn.py @@ -16,7 +16,7 @@ import torch.nn.functional as F import torch.utils.model_zoo as model_zoo from collections import OrderedDict -from .adaptive_avgmax_pool import adaptive_avgmax_pool2d +from .adaptive_avgmax_pool import select_adaptive_pool2d __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107'] @@ -41,31 +41,31 @@ model_urls = { } -def dpn68(num_classes=1000, pretrained=False, test_time_pool=7): +def dpn68(num_classes=1000, pretrained=False): model = DPN( small=True, num_init_features=10, k_r=128, groups=32, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes, test_time_pool=test_time_pool) + num_classes=num_classes) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['dpn68'])) return model -def dpn68b(num_classes=1000, pretrained=False, test_time_pool=7): +def dpn68b(num_classes=1000, pretrained=False): model = DPN( small=True, num_init_features=10, k_r=128, groups=32, b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes, test_time_pool=test_time_pool) + num_classes=num_classes) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra'])) return model -def dpn92(num_classes=1000, pretrained=False, test_time_pool=7, extra=True): +def dpn92(num_classes=1000, pretrained=False, extra=True): model = DPN( num_init_features=64, k_r=96, groups=32, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), - num_classes=num_classes, test_time_pool=test_time_pool) + num_classes=num_classes) if pretrained: if extra: model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra'])) @@ -74,31 +74,31 @@ def dpn92(num_classes=1000, pretrained=False, test_time_pool=7, extra=True): return model -def dpn98(num_classes=1000, pretrained=False, test_time_pool=7): +def dpn98(num_classes=1000, pretrained=False): model = DPN( num_init_features=96, k_r=160, groups=40, k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes, test_time_pool=test_time_pool) + num_classes=num_classes) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['dpn98'])) return model -def dpn131(num_classes=1000, pretrained=False, test_time_pool=7): +def dpn131(num_classes=1000, pretrained=False): model = DPN( num_init_features=128, k_r=160, groups=40, k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes, test_time_pool=test_time_pool) + num_classes=num_classes) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['dpn131'])) return model -def dpn107(num_classes=1000, pretrained=False, test_time_pool=7): +def dpn107(num_classes=1000, pretrained=False): model = DPN( num_init_features=128, k_r=200, groups=50, k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), - num_classes=num_classes, test_time_pool=test_time_pool) + num_classes=num_classes) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra'])) return model @@ -212,10 +212,9 @@ class DualPathBlock(nn.Module): class DPN(nn.Module): def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), - 
num_classes=1000, test_time_pool=0, fc_act=nn.ELU(inplace=True)): + num_classes=1000, fc_act=nn.ELU(inplace=True)): super(DPN, self).__init__() self.num_classes = num_classes - self.test_time_pool = test_time_pool self.b = b bw_factor = 1 if small else 4 @@ -287,20 +286,12 @@ class DPN(nn.Module): def forward_features(self, x, pool=True): x = self.features(x) if pool: - x = adaptive_avgmax_pool2d(x, pool_type='avg') - x = x.view(x.size(0), -1) + x = select_adaptive_pool2d(x, pool_type='avg') return x def forward(self, x): - x = self.features(x) - if not self.training and self.test_time_pool: - x = F.avg_pool2d(x, kernel_size=self.test_time_pool, stride=1) - out = self.classifier(x) - # The extra test time pool should be pooling an img_size//32 - 6 size patch - out = adaptive_avgmax_pool2d(out, pool_type='avgmax') - else: - x = adaptive_avgmax_pool2d(x, pool_type='avg') - out = self.classifier(x) + x = self.forward_features(x) + out = self.classifier(x) return out.view(out.size(0), -1) diff --git a/models/fbresnet200.py b/models/fbresnet200.py deleted file mode 100644 index 00da02b9..00000000 --- a/models/fbresnet200.py +++ /dev/null @@ -1,1254 +0,0 @@ -"""Facebook ResNet-200 Torch Model -Model with weights ported from https://github.com/facebook/fb.resnet.torch (BSD-3-Clause) -using https://github.com/clcarwin/convert_torch_to_pytorch (MIT) -""" -import torch -import torch.nn as nn -import torch.nn.init as init -import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo -from torch.autograd import Variable -from functools import reduce -from collections import OrderedDict -from .adaptive_avgmax_pool import * - -model_urls = { - 'fbresnet200': 'https://www.dropbox.com/s/tchq8fbdd4wabjx/fbresnet_200-37304a01b.pth?dl=1', -} - -class LambdaBase(nn.Sequential): - def __init__(self, fn, *args): - super(LambdaBase, self).__init__(*args) - self.lambda_func = fn - - def forward_prepare(self, input): - output = [] - for module in self._modules.values(): - output.append(module(input)) - return output if output else input - - -class Lambda(LambdaBase): - def forward(self, input): - return self.lambda_func(self.forward_prepare(input)) - - -class LambdaMap(LambdaBase): - def forward(self, input): - return list(map(self.lambda_func, self.forward_prepare(input))) - - -class LambdaReduce(LambdaBase): - def forward(self, input): - return reduce(self.lambda_func, self.forward_prepare(input)) - - -def fbresnet200_features(activation_fn=nn.ReLU()): - return nn.Sequential( # Sequential, - nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3)), - nn.BatchNorm2d(64), - activation_fn, - nn.MaxPool2d((3, 3), (2, 2), (1, 1)), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(64), - activation_fn, - nn.Conv2d(64, 64, (1, 1)), - nn.BatchNorm2d(64), - activation_fn, - nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(64), - activation_fn, - nn.Conv2d(64, 256, (1, 1)), - ), - nn.Sequential( # Sequential, - nn.Conv2d(64, 256, (1, 1)), - nn.BatchNorm2d(256), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 64, (1, 1)), - nn.BatchNorm2d(64), - activation_fn, - nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(64), - activation_fn, - nn.Conv2d(64, 256, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: 
x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 64, (1, 1)), - nn.BatchNorm2d(64), - activation_fn, - nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(64), - activation_fn, - nn.Conv2d(64, 256, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - nn.Sequential( # Sequential, - nn.Conv2d(256, 512, (1, 1), (2, 2)), - nn.BatchNorm2d(512), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, 
- ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - 
nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 128, (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - 
nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(128), - activation_fn, - nn.Conv2d(128, 512, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - nn.Sequential( # Sequential, - nn.Conv2d(512, 1024, (1, 1), (2, 2)), - nn.BatchNorm2d(1024), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - 
activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - 
activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - 
), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - 
LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 256, (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(256), - activation_fn, - nn.Conv2d(256, 1024, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(1024), - activation_fn, - nn.Conv2d(1024, 512, (1, 1)), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1)), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 2048, (1, 1)), - ), - nn.Sequential( # Sequential, - nn.Conv2d(1024, 2048, (1, 1), (2, 2)), - nn.BatchNorm2d(2048), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(2048), - activation_fn, - nn.Conv2d(2048, 512, (1, 1)), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 2048, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.BatchNorm2d(2048), - activation_fn, - nn.Conv2d(2048, 512, (1, 1)), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), - nn.BatchNorm2d(512), - activation_fn, - nn.Conv2d(512, 2048, (1, 1)), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - ), - ), - Lambda(lambda x: x), # Copy, - nn.BatchNorm2d(2048), - activation_fn, - ) - - -class ResNet200(nn.Module): - - def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'): - super(ResNet200, self).__init__() - self.drop_rate = drop_rate - self.global_pool = global_pool - self.num_classes = num_classes - self.num_features = 2048 - self.features = fbresnet200_features(activation_fn=activation_fn) - self.fc = 
nn.Linear(2048, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal(m.weight) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def get_classifier(self): - return self.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = global_pool - self.num_classes = num_classes - del self.fc - self.fc = nn.Linear(2048, num_classes) - - def forward_features(self, x, pool=True): - x = self.features(x) - if pool: - x = adaptive_avgmax_pool2d(x, self.global_pool) - x = x.view(x.size(0), -1) - return x - - def forward(self, x): - x = self.forward_features(x) - if self.drop_rate > 0: - x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.fc(x) - return x - - -def fbresnet200(pretrained=False, num_classes=1000, **kwargs): - model = ResNet200(num_classes=num_classes, **kwargs) - if pretrained: - # Remap pretrained weights to match our class module with features + fc - pretrained_weights = model_zoo.load_url(model_urls['fbresnet200']) - feature_keys = filter(lambda k: '13.1.' not in k, pretrained_weights.keys()) - remapped_weights = OrderedDict() - for k in feature_keys: - remapped_weights['features.' + k] = pretrained_weights[k] - remapped_weights['fc.weight'] = pretrained_weights['13.1.weight'] - remapped_weights['fc.bias'] = pretrained_weights['13.1.bias'] - model.load_state_dict(remapped_weights) - return model diff --git a/models/inception_resnet_v2.py b/models/inception_resnet_v2.py index c8bcf208..fabd9731 100644 --- a/models/inception_resnet_v2.py +++ b/models/inception_resnet_v2.py @@ -305,7 +305,7 @@ class InceptionResnetV2(nn.Module): x = self.block8(x) x = self.conv2d_7b(x) if pool: - x = adaptive_avgmax_pool2d(x, self.global_pool) + x = select_adaptive_pool2d(x, self.global_pool) #x = F.avg_pool2d(x, 8, count_include_pad=False) x = x.view(x.size(0), -1) return x diff --git a/models/inception_v4.py b/models/inception_v4.py index bcb84661..3de774df 100644 --- a/models/inception_v4.py +++ b/models/inception_v4.py @@ -272,7 +272,7 @@ class InceptionV4(nn.Module): def forward_features(self, x, pool=True): x = self.features(x) if pool: - x = adaptive_avgmax_pool2d(x, self.global_pool, count_include_pad=False) + x = select_adaptive_pool2d(x, self.global_pool, count_include_pad=False) x = x.view(x.size(0), -1) return x diff --git a/models/model_factory.py b/models/model_factory.py index 68d6bd6d..e3d0e2f3 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -1,15 +1,16 @@ import torch import os +from collections import OrderedDict from .inception_v4 import inception_v4 from .inception_resnet_v2 import inception_resnet_v2 from .densenet import densenet161, densenet121, densenet169, densenet201 -from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 -from .fbresnet200 import fbresnet200 +from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \ + resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107 from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \ seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d -from .resnext import resnext50, resnext101, resnext152 +#from .resnext import resnext50, resnext101, resnext152 from .xception import xception model_config_dict = { @@ -57,26 +58,18 @@ def create_model( checkpoint_path='', **kwargs): - test_time_pool = kwargs.pop('test_time_pool') if 'test_time_pool' in kwargs else 0 - 
     if model_name == 'dpn68':
-        model = dpn68(
-            num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
+        model = dpn68(num_classes=num_classes, pretrained=pretrained)
     elif model_name == 'dpn68b':
-        model = dpn68b(
-            num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
+        model = dpn68b(num_classes=num_classes, pretrained=pretrained)
     elif model_name == 'dpn92':
-        model = dpn92(
-            num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
+        model = dpn92(num_classes=num_classes, pretrained=pretrained)
     elif model_name == 'dpn98':
-        model = dpn98(
-            num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
+        model = dpn98(num_classes=num_classes, pretrained=pretrained)
     elif model_name == 'dpn131':
-        model = dpn131(
-            num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
+        model = dpn131(num_classes=num_classes, pretrained=pretrained)
     elif model_name == 'dpn107':
-        model = dpn107(
-            num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
+        model = dpn107(num_classes=num_classes, pretrained=pretrained)
     elif model_name == 'resnet18':
         model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'resnet34':
@@ -99,8 +92,6 @@ def create_model(
         model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'inception_v4':
         model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'fbresnet200':
-        model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'seresnet18':
         model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'seresnet34':
@@ -117,12 +108,14 @@ def create_model(
         model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'seresnext101_32x4d':
         model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnext50':
-        model = resnext50(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnext101':
-        model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnext152':
-        model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext50_32x4d':
+        model = resnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext101_32x4d':
+        model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext101_64x4d':
+        model = resnext101_64x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
+    elif model_name == 'resnext152_32x4d':
+        model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'xception':
         model = xception(num_classes=num_classes, pretrained=pretrained)
     else:
@@ -136,13 +129,22 @@ def create_model(
 
 
 def load_checkpoint(model, checkpoint_path):
-    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
-        print('Loading checkpoint', checkpoint_path)
+    if checkpoint_path and os.path.isfile(checkpoint_path):
+        print("=> Loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            model.load_state_dict(checkpoint['state_dict'])
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict'].items():
+                if k.startswith('module'):
+                    name = k[7:]  # remove `module.`
+                else:
+                    name = k
+                new_state_dict[name] = v
+            model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
+        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+        return True
     else:
-        print("Error: No checkpoint found at %s." % checkpoint_path)
-
+        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
+        return False
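Usage note (editor's sketch, not part of the patch): the factored-out load_checkpoint accepts checkpoints saved from either a bare model or a torch.nn.DataParallel wrapper (it strips the 'module.' key prefix) and returns a bool instead of exiting, so callers decide how to fail. A minimal illustration of the intended call pattern; the checkpoint path here is hypothetical:

    from models import create_model, load_checkpoint

    model = create_model('resnet50', num_classes=1000, pretrained=False)
    # Returns False rather than raising when the file is missing, mirroring
    # how inference.py uses it after this change.
    if not load_checkpoint(model, './output/checkpoint-1.pth.tar'):
        exit(1)
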
diff --git a/models/resnet.py b/models/resnet.py
index 743a5c80..08d836e7 100644
--- a/models/resnet.py
+++ b/models/resnet.py
@@ -7,7 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
 import torch.utils.model_zoo as model_zoo
-from .adaptive_avgmax_pool import AdaptiveAvgMaxPool2d
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
 
@@ -29,8 +29,13 @@ def conv3x3(in_planes, out_planes, stride=1):
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0):
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 cardinality=1, base_width=64, drop_rate=0.0):
         super(BasicBlock, self).__init__()
+
+        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
+        assert base_width == 64, 'BasicBlock does not support changing base width'
+
         self.conv1 = conv3x3(inplanes, planes, stride)
         self.bn1 = nn.BatchNorm2d(planes)
         self.relu = nn.ReLU(inplace=True)
@@ -65,14 +70,18 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0):
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 cardinality=1, base_width=64, drop_rate=0.0):
         super(Bottleneck, self).__init__()
-        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
-                               padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+
+        width = int(math.floor(planes * (base_width / 64)) * cardinality)
+
+        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(width)
+        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
+                               padding=1, groups=cardinality, bias=False)
+        self.bn2 = nn.BatchNorm2d(width)
+        self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * 4)
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
@@ -108,10 +117,13 @@ class Bottleneck(nn.Module):
 class ResNet(nn.Module):
 
     def __init__(self, block, layers, num_classes=1000,
+                 cardinality=1, base_width=64,
                  drop_rate=0.0, block_drop_rate=0.0,
                  global_pool='avg'):
         self.num_classes = num_classes
         self.inplanes = 64
+        self.cardinality = cardinality
+        self.base_width = base_width
         self.drop_rate = drop_rate
         self.expansion = block.expansion
         super(ResNet, self).__init__()
@@ -123,31 +135,29 @@ class ResNet(nn.Module):
         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate)
         self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate)
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate)
-        self.global_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_features = 512 * block.expansion
-        self.fc = nn.Linear(self.num_features,
num_classes) + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) - layers = [block(self.inplanes, planes, stride, downsample, drop_rate)] + layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width, drop_rate)] self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block(self.inplanes, planes)) + layers.append(block(self.inplanes, planes, cardinality=self.cardinality, base_width=self.base_width)) return nn.Sequential(*layers) @@ -155,11 +165,11 @@ class ResNet(nn.Module): return self.fc def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_classes = num_classes del self.fc if num_classes: - self.fc = nn.Linear(512 * self.expansion, num_classes) + self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) else: self.fc = None @@ -244,4 +254,52 @@ def resnet152(pretrained=False, **kwargs): model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) - return model \ No newline at end of file + return model + + +def resnext50_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model. + + Args: + cardinality (int): Cardinality of the aggregated transform + base_width (int): Base width of the grouped convolution + """ + model = ResNet( + Bottleneck, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs) + return model + + +def resnext101_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + + Args: + cardinality (int): Cardinality of the aggregated transform + base_width (int): Base width of the grouped convolution + """ + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs) + return model + + +def resnext101_64x4d(cardinality=64, base_width=4, pretrained=False, **kwargs): + """Constructs a ResNeXt101-64x4d model. + + Args: + cardinality (int): Cardinality of the aggregated transform + base_width (int): Base width of the grouped convolution + """ + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs) + return model + + +def resnext152_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): + """Constructs a ResNeXt152-32x4d model. 
+ + Args: + cardinality (int): Cardinality of the aggregated transform + base_width (int): Base width of the grouped convolution + """ + model = ResNet( + Bottleneck, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs) + return model diff --git a/models/resnext.py b/models/resnext.py deleted file mode 100644 index aafcd93b..00000000 --- a/models/resnext.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F -import math -import torch.utils.model_zoo as model_zoo -from models.adaptive_avgmax_pool import AdaptiveAvgMaxPool2d - -__all__ = ['ResNeXt', 'resnext50', 'resnext101', 'resnext152'] - - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) - - -class ResNeXtBottleneckC(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=32, base_width=4): - super(ResNeXtBottleneckC, self).__init__() - - width = math.floor(planes * (base_width / 64)) * cardinality - - self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, - padding=1, bias=False, groups=cardinality) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNeXt(nn.Module): - - def __init__(self, block, layers, num_classes=1000, cardinality=32, base_width=4, - drop_rate=0., global_pool='avg'): - self.num_classes = num_classes - self.inplanes = 64 - self.cardinality = cardinality - self.base_width = base_width - self.drop_rate = drop_rate - super(ResNeXt, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2) - self.avgpool = AdaptiveAvgMaxPool2d(pool_type=global_pool) - self.num_features = 512 * block.expansion - self.fc = nn.Linear(self.num_features, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1.) - nn.init.constant_(m.bias, 0.) 
-
-    def _make_layer(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                nn.Conv2d(
-                    self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(planes * block.expansion),
-            )
-
-        layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width)]
-        self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, 1, None, self.cardinality, self.base_width))
-
-        return nn.Sequential(*layers)
-
-    def get_classifier(self):
-        return self.fc
-
-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.avgpool = AdaptiveAvgMaxPool2d(pool_type=global_pool)
-        self.num_classes = num_classes
-        del self.fc
-        if num_classes:
-            self.fc = nn.Linear(self.num_features, num_classes)
-        else:
-            self.fc = None
-
-    def forward_features(self, x, pool=True):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        if pool:
-            x = self.avgpool(x)
-            x = x.view(x.size(0), -1)
-        return x
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
-        x = self.fc(x)
-
-        return x
-
-
-def resnext50(cardinality=32, base_width=4, pretrained=False, **kwargs):
-    """Constructs a ResNeXt-50 model.
-
-    Args:
-        cardinality (int): Cardinality of the aggregated transform
-        base_width (int): Base width of the grouped convolution
-        shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
-    """
-    model = ResNeXt(
-        ResNeXtBottleneckC, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs)
-    return model
-
-
-def resnext101(cardinality=32, base_width=4, pretrained=False, **kwargs):
-    """Constructs a ResNeXt-101 model.
-
-    Args:
-        cardinality (int): Cardinality of the aggregated transform
-        base_width (int): Base width of the grouped convolution
-        shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
-    """
-    model = ResNeXt(
-        ResNeXtBottleneckC, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
-    return model
-
-
-def resnext152(cardinality=32, base_width=4, pretrained=False, **kwargs):
-    """Constructs a ResNeXt-152 model.
-
-    Args:
-        cardinality (int): Cardinality of the aggregated transform
-        base_width (int): Base width of the grouped convolution
-        shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
-    """
-    model = ResNeXt(
-        ResNeXtBottleneckC, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs)
-    return model
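The senet.py hunk below is the same mechanical swap to the factored-out pooling module. A quick sketch of the SelectAdaptivePool2d behavior these models now share, assuming a typical 2048-channel 7x7 feature map:

    import torch
    from models.adaptive_avgmax_pool import SelectAdaptivePool2d

    feats = torch.randn(2, 2048, 7, 7)
    for pool_type in ('avg', 'max', 'avgmax', 'catavgmax'):
        pool = SelectAdaptivePool2d(pool_type=pool_type)
        out = pool(feats)  # -> (2, C, 1, 1)
        # only 'catavgmax' concatenates along channels, so feat_mult() is 2 there, 1 otherwise
        assert out.shape[1] == 2048 * pool.feat_mult()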
diff --git a/models/senet.py b/models/senet.py
index e66082b8..bacec15f 100644
--- a/models/senet.py
+++ b/models/senet.py
@@ -9,7 +9,7 @@ import math
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils import model_zoo
-from models.adaptive_avgmax_pool import AdaptiveAvgMaxPool2d
+from models.adaptive_avgmax_pool import SelectAdaptivePool2d
 
 __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
            'seresnext50_32x4d', 'seresnext101_32x4d']
@@ -307,7 +307,7 @@ class SENet(nn.Module):
             downsample_kernel_size=downsample_kernel_size,
             downsample_padding=downsample_padding
         )
-        self.avg_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool)
+        self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.drop_rate = drop_rate
         self.num_features = 512 * block.expansion
         self.last_linear = nn.Linear(self.num_features, num_classes)
diff --git a/models/test_time_pool.py b/models/test_time_pool.py
new file mode 100644
index 00000000..269f15f8
--- /dev/null
+++ b/models/test_time_pool.py
@@ -0,0 +1,27 @@
+from torch import nn
+import torch.nn.functional as F
+from models.adaptive_avgmax_pool import adaptive_avgmax_pool2d
+
+
+class TestTimePoolHead(nn.Module):
+    def __init__(self, base, original_pool=7):
+        super(TestTimePoolHead, self).__init__()
+        self.base = base
+        self.original_pool = original_pool
+        base_fc = self.base.get_classifier()
+        if isinstance(base_fc, nn.Conv2d):
+            self.fc = base_fc
+        else:
+            self.fc = nn.Conv2d(
+                self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
+            self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
+            self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
+        self.base.reset_classifier(0)  # delete original fc layer
+
+    def forward(self, x):
+        x = self.base.forward_features(x, pool=False)
+        x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
+        x = self.fc(x)
+        x = adaptive_avgmax_pool2d(x, 1)
+        return x.view(x.size(0), -1)
+
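TestTimePoolHead keeps the model's original pooling window, slides it over the larger test-time feature map, applies the classifier as a 1x1 conv at every position, and avg+max pools the resulting class map. A hedged usage sketch, mirroring how the validate.py change below wraps the model; it assumes the base model's forward_features accepts pool=False, and the resnet50 choice and 288x288 input are illustrative:

    import torch
    from models import create_model, TestTimePoolHead

    # a 224-trained model evaluated at 288x288: layer4 output is 9x9 instead of 7x7
    model = create_model('resnet50', num_classes=1000, pretrained=False)
    model = TestTimePoolHead(model, original_pool=7).eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 288, 288))
    assert out.shape == (2, 1000)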
diff --git a/validate.py b/validate.py
index 9d21c873..1e82a1fc 100644
--- a/validate.py
+++ b/validate.py
@@ -9,9 +9,8 @@ import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.parallel
-from collections import OrderedDict
 
-from models import create_model
+from models import create_model, load_checkpoint, TestTimePoolHead
 from data import Dataset, create_loader, get_model_meanstd
 
 
@@ -41,40 +40,26 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
 def main():
     args = parser.parse_args()
 
-    test_time_pool = False
-    if 'dpn' in args.model and args.img_size > 224 and not args.no_test_pool:
-        test_time_pool = True
-
     # create model
     num_classes = 1000
     model = create_model(
         args.model,
         num_classes=num_classes,
-        pretrained=args.pretrained,
-        test_time_pool=test_time_pool)
+        pretrained=args.pretrained)
 
     print('Model %s created, param count: %d' %
           (args.model, sum([m.numel() for m in model.parameters()])))
 
-    # optionally resume from a checkpoint
-    if args.checkpoint and os.path.isfile(args.checkpoint):
-        print("=> loading checkpoint '{}'".format(args.checkpoint))
-        checkpoint = torch.load(args.checkpoint)
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
-                new_state_dict[name] = v
-            model.load_state_dict(new_state_dict)
-        else:
-            model.load_state_dict(checkpoint)
-        print("=> loaded checkpoint '{}'".format(args.checkpoint))
-    elif not args.pretrained:
-        print("=> no checkpoint found at '{}'".format(args.checkpoint))
-        exit(1)
+    # load a checkpoint
+    if not args.pretrained:
+        if not load_checkpoint(model, args.checkpoint):
+            exit(1)
+
+    test_time_pool = False
+    # FIXME make this work for networks with default img size != 224 and default pool k != 7
+    if args.img_size > 224 and not args.no_test_pool:
+        model = TestTimePoolHead(model)
+        test_time_pool = True
 
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@@ -94,7 +79,8 @@ def main():
         use_prefetcher=True,
         mean=data_mean,
         std=data_std,
-        num_workers=args.workers)
+        num_workers=args.workers,
+        crop_pct=1.0 if test_time_pool else None)
 
     batch_time = AverageMeter()
     losses = AverageMeter()