You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/models/fbresnet200.py

1255 lines
58 KiB

"""Facebook ResNet-200 Torch Model
Model with weights ported from https://github.com/facebook/fb.resnet.torch (BSD-3-Clause)
using https://github.com/clcarwin/convert_torch_to_pytorch (MIT)
"""
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
from functools import reduce
from collections import OrderedDict
from .adaptive_avgmax_pool import *
# Download locations of the ported checkpoints, keyed by model name.
model_urls = dict(
    fbresnet200='https://www.dropbox.com/s/tchq8fbdd4wabjx/fbresnet_200-37304a01b.pth?dl=1',
)
class LambdaBase(nn.Sequential):
    """Sequential container that also carries a plain callable.

    The positional modules are registered as index-named children, exactly
    as ``nn.Sequential`` does; ``fn`` is stored as ``lambda_func`` for
    subclasses to apply in their ``forward``.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Feed ``input`` to every child independently.

        Returns a list with one output per child (in registration order),
        or ``input`` itself, unchanged, when there are no children.
        """
        if not self._modules:
            return input
        return [branch(input) for branch in self._modules.values()]
class Lambda(LambdaBase):
    """Apply ``lambda_func`` once to the whole prepared input."""

    def forward(self, input):
        # With no children this is just lambda_func(input); with children
        # lambda_func receives the list of their outputs.
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
class LambdaMap(LambdaBase):
    """Apply ``lambda_func`` to each prepared output; return the results as a list."""

    def forward(self, input):
        fn = self.lambda_func
        return [fn(item) for item in self.forward_prepare(input)]
class LambdaReduce(LambdaBase):
    """Fold the prepared outputs left-to-right with the binary ``lambda_func``."""

    def forward(self, input):
        # With no children, forward_prepare hands back `input` itself, so the
        # fold runs over whatever sequence the previous module produced.
        operands = self.forward_prepare(input)
        return reduce(self.lambda_func, operands)
def _fbresnet_bottleneck(in_chs, mid_chs, out_chs, stride, activation_fn, project):
    """Build one bottleneck block with BN -> activation -> conv ordering.

    The block is ``Sequential(LambdaMap(residual, shortcut), LambdaReduce(add))``:
    both branches receive the same input and their outputs are summed.

    Args:
        in_chs: Channels entering the block.
        mid_chs: Channels of the 1x1-reduce and 3x3 convolutions.
        out_chs: Channels leaving the block.
        stride: Stride of the 3x3 convolution (and of the projection
            shortcut when ``project`` is true).
        activation_fn: Activation module instance inserted after every
            BatchNorm of the residual branch (the same object is reused).
        project: If true the shortcut is a 1x1 conv + BatchNorm projection,
            otherwise an identity ``Lambda``.
    """
    # The residual branch is built before the shortcut so that parameters are
    # created in the same order as in the hand-unrolled definition.
    residual = nn.Sequential(
        nn.BatchNorm2d(in_chs),
        activation_fn,
        nn.Conv2d(in_chs, mid_chs, (1, 1)),
        nn.BatchNorm2d(mid_chs),
        activation_fn,
        nn.Conv2d(mid_chs, mid_chs, (3, 3), (stride, stride), (1, 1)),
        nn.BatchNorm2d(mid_chs),
        activation_fn,
        nn.Conv2d(mid_chs, out_chs, (1, 1)),
    )
    if project:
        shortcut = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, (1, 1), (stride, stride)),
            nn.BatchNorm2d(out_chs),
        )
    else:
        shortcut = Lambda(lambda x: x)  # Identity
    return nn.Sequential(
        LambdaMap(lambda x: x, residual, shortcut),  # ConcatTable
        LambdaReduce(lambda x, y: x + y),  # CAddTable
    )


def _fbresnet_stage(in_chs, mid_chs, out_chs, depth, stride, activation_fn):
    """Build a stage: one projecting block followed by ``depth - 1`` identity blocks.

    Only the first block changes the channel count (``in_chs`` -> ``out_chs``)
    and applies ``stride``; the remaining blocks map ``out_chs`` -> ``out_chs``
    with stride 1.
    """
    blocks = [_fbresnet_bottleneck(in_chs, mid_chs, out_chs, stride, activation_fn, project=True)]
    for _ in range(depth - 1):
        blocks.append(_fbresnet_bottleneck(out_chs, mid_chs, out_chs, 1, activation_fn, project=False))
    return nn.Sequential(*blocks)


def fbresnet200_features(activation_fn=nn.ReLU()):
    """Build the convolutional feature extractor of the ported ResNet-200.

    Layout of the returned ``nn.Sequential`` (child index in brackets):
    7x7/2 conv [0], BatchNorm [1], activation [2], 3x3/2 max-pool [3], four
    bottleneck stages with 3, 24, 36 and 3 blocks [4-7], an identity
    ``Lambda`` [8], BatchNorm(2048) [9] and a final activation [10].

    The module nesting is kept exactly as produced by the Torch converter, so
    parameter names (e.g. ``5.13.0.0.2.weight``) line up with the ported
    checkpoint.

    Args:
        activation_fn: Activation module instance used at every activation
            position. The default ``nn.ReLU()`` is created once when this
            function is defined, so that one instance is shared by every
            call that relies on the default.

    Returns:
        nn.Sequential producing a 2048-channel feature map.
    """
    return nn.Sequential(
        nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3)),
        nn.BatchNorm2d(64),
        activation_fn,
        nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
        _fbresnet_stage(64, 64, 256, 3, 1, activation_fn),
        _fbresnet_stage(256, 128, 512, 24, 2, activation_fn),
        _fbresnet_stage(512, 256, 1024, 36, 2, activation_fn),
        _fbresnet_stage(1024, 512, 2048, 3, 2, activation_fn),
        Lambda(lambda x: x),  # Copy
        nn.BatchNorm2d(2048),
        activation_fn,
    )
class ResNet200(nn.Module):
    """ResNet-200 classifier: ported feature extractor + global pool + linear head.

    Args:
        num_classes: Output size of the final linear layer.
        activation_fn: Activation module instance handed to
            ``fbresnet200_features``. The default instance is created once at
            class-definition time and shared by every model using the default.
        drop_rate: Dropout probability applied to the pooled features before
            the classifier; dropout is skipped when it is not > 0.
        global_pool: Pooling mode string forwarded to
            ``adaptive_avgmax_pool2d``.
    """

    def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
        super(ResNet200, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.num_features = 2048
        self.features = fbresnet200_features(activation_fn=activation_fn)
        self.fc = nn.Linear(self.num_features, num_classes)

        # Re-initialise every conv and BatchNorm layer in the network.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # kaiming_normal_ is the in-place initialiser; the name without
                # the trailing underscore is a deprecated alias of it.
                init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def get_classifier(self):
        """Return the final linear layer."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the head with a freshly initialised ``num_classes``-way linear layer.

        Also updates the stored ``global_pool`` mode and ``num_classes``.
        """
        self.global_pool = global_pool
        self.num_classes = num_classes
        del self.fc
        self.fc = nn.Linear(self.num_features, num_classes)

    def forward_features(self, x, pool=True):
        """Run the feature extractor.

        With ``pool=True`` the feature map is reduced by
        ``adaptive_avgmax_pool2d`` using ``self.global_pool`` and flattened to
        ``(batch, -1)``; with ``pool=False`` the output of ``self.features``
        is returned as is.
        """
        x = self.features(x)
        if pool:
            x = adaptive_avgmax_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return classifier outputs for the input batch ``x``."""
        x = self.forward_features(x)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        return x
def fbresnet200(pretrained=False, num_classes=1000, **kwargs):
    """Construct a ResNet200, optionally loading the ported Facebook weights.

    Args:
        pretrained: If true, download the checkpoint listed in ``model_urls``
            and load it into the model.
        num_classes: Output size of the classifier.
        **kwargs: Forwarded to ``ResNet200``.

    Returns:
        The constructed ``ResNet200``.

    When ``pretrained`` is true the backbone weights are always loaded
    (strictly). The checkpoint classifier is loaded only if its output size
    equals ``num_classes``; otherwise the freshly initialised ``fc`` layer is
    kept, instead of failing with a size-mismatch error.
    """
    model = ResNet200(num_classes=num_classes, **kwargs)
    if pretrained:
        pretrained_weights = model_zoo.load_url(model_urls['fbresnet200'])

        # In the checkpoint the classifier lives under '13.1.'; every other
        # entry belongs to the feature extractor and already uses the key
        # names of `model.features`.
        feature_weights = OrderedDict(
            (k, v) for k, v in pretrained_weights.items() if '13.1.' not in k)
        model.features.load_state_dict(feature_weights)

        fc_weight = pretrained_weights['13.1.weight']
        fc_bias = pretrained_weights['13.1.bias']
        if fc_weight.size(0) == num_classes:
            model.fc.load_state_dict(OrderedDict([
                ('weight', fc_weight),
                ('bias', fc_bias),
            ]))
    return model