"""Facebook ResNet-200 Torch Model

Model with weights ported from https://github.com/facebook/fb.resnet.torch (BSD-3-Clause)
using https://github.com/clcarwin/convert_torch_to_pytorch (MIT)
"""
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
from functools import reduce
from collections import OrderedDict
from .adaptive_avgmax_pool import *

model_urls = {
    'fbresnet200': 'https://www.dropbox.com/s/tchq8fbdd4wabjx/fbresnet_200-37304a01b.pth?dl=1',
}

# Per-stage settings: (bottleneck width, output channels, number of blocks,
# stride of the first block). The first block of every stage uses a
# convolutional shortcut; the remaining blocks use an identity shortcut.
_STAGE_CONFIG = (
    (64, 256, 3, 1),
    (128, 512, 24, 2),
    (256, 1024, 36, 2),
    (512, 2048, 3, 2),
)


def _identity(x):
    """Return the argument unchanged (Torch7 ``Identity`` / ``Copy``)."""
    return x


def _add(x, y):
    """Return ``x + y`` (Torch7 ``CAddTable`` reduction step)."""
    return x + y


class LambdaBase(nn.Sequential):
    """Sequential container that also stores a plain callable.

    Unlike ``nn.Sequential`` the children are not chained: every child is
    applied to the same input and the results are collected in a list.

    Args:
        fn: callable stored as ``lambda_func`` and used by the subclasses.
        *args: child modules, registered exactly as ``nn.Sequential`` does.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Apply every child to ``input`` and return the list of outputs.

        When the container has no children the input itself is returned.
        """
        output = [module(input) for module in self._modules.values()]
        return output if output else input


class Lambda(LambdaBase):
    """Call ``lambda_func`` once on the prepared input."""

    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))


class LambdaMap(LambdaBase):
    """Map ``lambda_func`` over the prepared input and return a list."""

    def forward(self, input):
        return list(map(self.lambda_func, self.forward_prepare(input)))


class LambdaReduce(LambdaBase):
    """Fold the prepared input with ``lambda_func`` via ``functools.reduce``."""

    def forward(self, input):
        return reduce(self.lambda_func, self.forward_prepare(input))


def _bottleneck(in_chs, mid_chs, out_chs, stride, activation_fn, project):
    """Build one residual block as ``Sequential(LambdaMap, LambdaReduce)``.

    The residual branch is BN -> act -> 1x1 conv -> BN -> act -> 3x3 conv
    (carrying ``stride``) -> BN -> act -> 1x1 conv. The shortcut is a strided
    1x1 conv followed by BN when ``project`` is true, otherwise an identity.

    The residual branch is created before the shortcut so that parameters are
    constructed in the same order as in the original hand-written layer list.
    """
    residual = nn.Sequential(
        nn.BatchNorm2d(in_chs),
        activation_fn,
        nn.Conv2d(in_chs, mid_chs, (1, 1)),
        nn.BatchNorm2d(mid_chs),
        activation_fn,
        nn.Conv2d(mid_chs, mid_chs, (3, 3), (stride, stride), (1, 1)),
        nn.BatchNorm2d(mid_chs),
        activation_fn,
        nn.Conv2d(mid_chs, out_chs, (1, 1)),
    )
    if project:
        shortcut = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, (1, 1), (stride, stride)),
            nn.BatchNorm2d(out_chs),
        )
    else:
        shortcut = Lambda(_identity)  # Identity
    return nn.Sequential(
        LambdaMap(_identity, residual, shortcut),  # ConcatTable
        LambdaReduce(_add),  # CAddTable
    )


def _stage(in_chs, mid_chs, out_chs, num_blocks, stride, activation_fn):
    """Build one stage: a projecting block followed by identity blocks."""
    blocks = [_bottleneck(in_chs, mid_chs, out_chs, stride, activation_fn, project=True)]
    for _ in range(num_blocks - 1):
        blocks.append(
            _bottleneck(out_chs, mid_chs, out_chs, 1, activation_fn, project=False))
    return nn.Sequential(*blocks)


def fbresnet200_features(activation_fn=None):
    """Build the convolutional feature extractor of ResNet-200.

    The module tree (and therefore every ``state_dict`` key) is laid out as
    follows so that the ported weights keep loading:

        0 conv7x7/2, 1 BN, 2 activation, 3 max-pool,
        4-7 residual stages with 3 / 24 / 36 / 3 blocks,
        8 identity (Torch7 ``Copy``), 9 BN, 10 activation.

    Args:
        activation_fn: activation module inserted at every activation site;
            the same instance is reused throughout. ``None`` (default)
            creates a fresh ``nn.ReLU()`` for this call.

    Returns:
        ``nn.Sequential`` whose final BatchNorm has 2048 channels.
    """
    if activation_fn is None:
        # A fresh instance per call avoids sharing one module object between
        # unrelated models through a default argument.
        activation_fn = nn.ReLU()

    layers = [
        nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3)),
        nn.BatchNorm2d(64),
        activation_fn,
        nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
    ]
    in_chs = 64
    for mid_chs, out_chs, num_blocks, stride in _STAGE_CONFIG:
        layers.append(_stage(in_chs, mid_chs, out_chs, num_blocks, stride, activation_fn))
        in_chs = out_chs
    layers.extend([
        Lambda(_identity),  # Copy
        nn.BatchNorm2d(in_chs),
        activation_fn,
    ])
    return nn.Sequential(*layers)


class ResNet200(nn.Module):
    """ResNet-200 classifier: feature extractor, global pooling and a linear head.

    Args:
        num_classes: output size of the final ``nn.Linear``.
        activation_fn: activation module forwarded to
            ``fbresnet200_features``; ``None`` means a new ``nn.ReLU()``.
        drop_rate: dropout probability applied to the pooled features
            before the classifier (skipped when not greater than 0).
        global_pool: pooling mode handed to ``adaptive_avgmax_pool2d``.
    """

    def __init__(self, num_classes=1000, activation_fn=None, drop_rate=0., global_pool='avg'):
        super(ResNet200, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.num_features = 2048
        self.features = fbresnet200_features(activation_fn=activation_fn)
        self.fc = nn.Linear(self.num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # In-place variant; the un-suffixed name is a deprecated alias.
                init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def get_classifier(self):
        """Return the final linear layer."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the linear head with a newly initialised one.

        Also updates the stored ``global_pool`` and ``num_classes``.
        """
        self.global_pool = global_pool
        self.num_classes = num_classes
        del self.fc
        self.fc = nn.Linear(self.num_features, num_classes)

    def forward_features(self, x, pool=True):
        """Run the feature extractor, optionally pooling and flattening.

        With ``pool=True`` the output of ``adaptive_avgmax_pool2d`` is
        reshaped to ``(batch, -1)``; otherwise the raw feature map is
        returned.
        """
        x = self.features(x)
        if pool:
            x = adaptive_avgmax_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return classifier outputs for the input batch ``x``."""
        x = self.forward_features(x)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        return x


def fbresnet200(pretrained=False, num_classes=1000, **kwargs):
    """Construct a ``ResNet200``, optionally loading the ported weights.

    Args:
        pretrained: download the checkpoint listed in ``model_urls`` and
            load it into the model.
        num_classes: size of the classifier output. If it differs from the
            checkpoint's classifier, the feature weights are still loaded and
            the classifier keeps its fresh initialisation.
        **kwargs: forwarded to ``ResNet200``.

    Returns:
        The constructed model.
    """
    model = ResNet200(num_classes=num_classes, **kwargs)
    if pretrained:
        # Remap pretrained weights to match our class module with features + fc
        pretrained_weights = model_zoo.load_url(model_urls['fbresnet200'])
        remapped_weights = OrderedDict(
            ('features.' + k, v)
            for k, v in pretrained_weights.items()
            if '13.1.' not in k
        )
        fc_weight = pretrained_weights['13.1.weight']
        if fc_weight.size(0) == num_classes:
            remapped_weights['fc.weight'] = fc_weight
            remapped_weights['fc.bias'] = pretrained_weights['13.1.bias']
        else:
            # Checkpoint head has a different number of classes: feed the
            # model's own classifier tensors back in so the strict load
            # still verifies every feature key.
            for name, tensor in model.fc.state_dict().items():
                remapped_weights['fc.' + name] = tensor
        model.load_state_dict(remapped_weights)
    return model