You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1255 lines
58 KiB
1255 lines
58 KiB
"""Facebook ResNet-200 Torch Model
|
|
Model with weights ported from https://github.com/facebook/fb.resnet.torch (BSD-3-Clause)
|
|
using https://github.com/clcarwin/convert_torch_to_pytorch (MIT)
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.init as init
|
|
import torch.nn.functional as F
|
|
import torch.utils.model_zoo as model_zoo
|
|
from torch.autograd import Variable
|
|
from functools import reduce
|
|
from collections import OrderedDict
|
|
from .adaptive_avgmax_pool import *
|
|
|
|
# Download location of the pretrained checkpoint, keyed by model name.
model_urls = dict(
    fbresnet200='https://www.dropbox.com/s/tchq8fbdd4wabjx/fbresnet_200-37304a01b.pth?dl=1',
)
|
|
|
|
class LambdaBase(nn.Sequential):
    """Sequential container that additionally stores a callable for subclasses to apply.

    Unlike a plain ``nn.Sequential``, the child modules are not chained here:
    ``forward_prepare`` feeds the same input to every child independently.
    """

    def __init__(self, fn, *args):
        """Register ``args`` as child modules and keep ``fn`` as ``self.lambda_func``."""
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run every child module on ``input`` and collect the results.

        Returns:
            A list with one output per child module, in registration order.
            When the container has no children, ``input`` itself is returned
            unchanged (not wrapped in a list).
        """
        branch_outputs = [branch(input) for branch in self._modules.values()]
        if not branch_outputs:
            return input
        return branch_outputs
|
|
|
|
|
|
class Lambda(LambdaBase):
    """Apply the stored callable once to whatever ``forward_prepare`` yields."""

    def forward(self, input):
        # With no child modules this is simply lambda_func(input); otherwise
        # the callable receives the list of per-child outputs.
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
|
|
|
|
|
|
class LambdaMap(LambdaBase):
    """Apply the stored callable to each element produced by ``forward_prepare``.

    Marked ``# ConcatTable`` where it is used in this file: every child module
    sees the same input and the results come back as a list.
    """

    def forward(self, input):
        fn = self.lambda_func
        return [fn(item) for item in self.forward_prepare(input)]
|
|
|
|
|
|
class LambdaReduce(LambdaBase):
    """Fold the result of ``forward_prepare`` with the stored binary callable.

    Marked ``# CAddTable`` where it is used in this file, with ``x + y`` as the
    callable, i.e. it sums the branch outputs it is handed.
    """

    def forward(self, input):
        # No initial value is supplied, so the first element seeds the fold.
        branch_outputs = self.forward_prepare(input)
        return reduce(self.lambda_func, branch_outputs)
|
|
|
|
|
|
def _fbresnet_bottleneck(in_chs, mid_chs, out_chs, stride, project, activation_fn):
    """Build one bottleneck unit in the nested layout of the converted Torch model.

    Layout (child indices matter because checkpoint keys are derived from them):
        Sequential(
            0: LambdaMap(identity,            # ConcatTable
                   0: Sequential(BN, act, 1x1 conv, BN, act, 3x3 conv, BN, act, 1x1 conv),
                   1: shortcut),
            1: LambdaReduce(x + y),           # CAddTable
        )

    Args:
        in_chs: Channels entering the unit.
        mid_chs: Channels of the 1x1-reduced / 3x3 convolutions.
        out_chs: Channels leaving the unit.
        stride: Stride of the 3x3 convolution (and of the projection, if any).
        project: If true the shortcut is ``Conv2d(1x1, stride) + BatchNorm2d``;
            otherwise it is an identity ``Lambda``.
        activation_fn: Activation module instance inserted after each BatchNorm.
    """
    # Modules are created in the same order as in the unrolled definition
    # (residual convs first, then the shortcut conv) so that the default
    # random initialisation consumes the RNG stream identically.
    residual = nn.Sequential(
        nn.BatchNorm2d(in_chs),
        activation_fn,
        nn.Conv2d(in_chs, mid_chs, (1, 1)),
        nn.BatchNorm2d(mid_chs),
        activation_fn,
        nn.Conv2d(mid_chs, mid_chs, (3, 3), (stride, stride), (1, 1)),
        nn.BatchNorm2d(mid_chs),
        activation_fn,
        nn.Conv2d(mid_chs, out_chs, (1, 1)),
    )
    if project:
        shortcut = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, (1, 1), (stride, stride)),
            nn.BatchNorm2d(out_chs),
        )
    else:
        shortcut = Lambda(lambda x: x)  # Identity
    return nn.Sequential(
        LambdaMap(lambda x: x, residual, shortcut),  # ConcatTable
        LambdaReduce(lambda x, y: x + y),  # CAddTable
    )


def _fbresnet_stage(in_chs, mid_chs, num_blocks, stride, activation_fn):
    """Build one stage: a projecting first unit followed by identity-shortcut units.

    The stage outputs ``4 * mid_chs`` channels. Only the first unit changes the
    channel count and applies ``stride``; the remaining ``num_blocks - 1`` units
    keep shape and use stride 1.
    """
    out_chs = mid_chs * 4
    blocks = [_fbresnet_bottleneck(in_chs, mid_chs, out_chs, stride, True, activation_fn)]
    blocks += [
        _fbresnet_bottleneck(out_chs, mid_chs, out_chs, 1, False, activation_fn)
        for _ in range(num_blocks - 1)
    ]
    return nn.Sequential(*blocks)


def fbresnet200_features(activation_fn=nn.ReLU()):
    """Return the ResNet-200 feature trunk as a nested ``nn.Sequential``.

    Top-level children, by index:
        0-3   7x7/2 conv (3 -> 64), BatchNorm, activation, 3x3/2 max-pool
        4     stage of  3 units,   64 ->  256 channels, stride 1
        5     stage of 24 units,  256 ->  512 channels, stride 2
        6     stage of 36 units,  512 -> 1024 channels, stride 2
        7     stage of  3 units, 1024 -> 2048 channels, stride 2
        8-10  identity ``Lambda`` (Copy), BatchNorm2d(2048), activation

    The nesting and indices reproduce the previously hand-unrolled definition
    exactly, so ``state_dict`` keys are unchanged.

    Args:
        activation_fn: Activation module instance. The same instance is placed
            at every activation position in the trunk.
    """
    return nn.Sequential(
        nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3)),
        nn.BatchNorm2d(64),
        activation_fn,
        nn.MaxPool2d((3, 3), (2, 2), (1, 1)),
        _fbresnet_stage(64, 64, 3, 1, activation_fn),
        _fbresnet_stage(256, 128, 24, 2, activation_fn),
        _fbresnet_stage(512, 256, 36, 2, activation_fn),
        _fbresnet_stage(1024, 512, 3, 2, activation_fn),
        Lambda(lambda x: x),  # Copy
        nn.BatchNorm2d(2048),
        activation_fn,
    )
|
|
|
|
|
|
class ResNet200(nn.Module):
    """ResNet-200 classifier: feature trunk, global pooling and a linear head.

    Args:
        num_classes: Output size of the final linear layer ``fc``.
        activation_fn: Activation module instance forwarded to
            ``fbresnet200_features``. Note the default ``nn.ReLU()`` is created
            once at definition time and therefore shared between models.
        drop_rate: Dropout probability applied to the pooled features in
            ``forward``; dropout is skipped unless this is greater than 0.
        global_pool: Pooling mode passed to ``adaptive_avgmax_pool2d``.
    """

    def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
        super(ResNet200, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        # Channel count produced by the trunk (its last layer is BatchNorm2d(2048)).
        self.num_features = 2048
        self.features = fbresnet200_features(activation_fn=activation_fn)
        self.fc = nn.Linear(self.num_features, num_classes)

        # Prefer the in-place initialiser name; fall back to the old alias on
        # PyTorch releases that predate it.
        kaiming_init = getattr(init, 'kaiming_normal_', None)
        if kaiming_init is None:
            kaiming_init = init.kaiming_normal
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def get_classifier(self):
        """Return the final linear layer."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the linear head with a fresh one and update the pooling mode.

        Args:
            num_classes: Output size of the new head.
            global_pool: New pooling mode used by ``forward_features``.
        """
        self.global_pool = global_pool
        self.num_classes = num_classes
        del self.fc
        self.fc = nn.Linear(self.num_features, num_classes)

    def forward_features(self, x, pool=True):
        """Run the trunk on ``x``.

        With ``pool`` true the result is globally pooled with the configured
        mode and flattened to ``(batch, -1)``; otherwise the trunk output is
        returned as is.
        """
        x = self.features(x)
        if pool:
            x = adaptive_avgmax_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return classifier outputs for ``x`` (pooled features -> dropout -> fc)."""
        x = self.forward_features(x)
        if self.drop_rate > 0:
            # F.dropout is only active while the module is in training mode.
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        return x
|
|
|
|
|
|
def fbresnet200(pretrained=False, num_classes=1000, **kwargs):
    """Construct a ``ResNet200`` model, optionally loading the ported weights.

    Args:
        pretrained: If true, download the checkpoint listed under
            ``model_urls['fbresnet200']`` and load it into the model.
        num_classes: Output size of the classifier.
        **kwargs: Forwarded to ``ResNet200``.

    Returns:
        The constructed model. When ``pretrained`` is set and ``num_classes``
        differs from the checkpoint's classifier size, the trunk weights are
        still loaded and the freshly initialised ``fc`` layer is kept instead of
        failing on a shape mismatch.
    """
    model = ResNet200(num_classes=num_classes, **kwargs)
    if pretrained:
        pretrained_weights = model_zoo.load_url(model_urls['fbresnet200'])
        # The checkpoint is one flat Sequential: entries under '13.1.' are the
        # linear classifier, everything else belongs to the feature trunk.
        remapped_weights = OrderedDict(
            ('features.' + k, v)
            for k, v in pretrained_weights.items()
            if '13.1.' not in k
        )
        fc_weight = pretrained_weights['13.1.weight']
        fc_bias = pretrained_weights['13.1.bias']
        if fc_weight.size(0) != num_classes:
            # Checkpoint head does not fit the requested class count: feed the
            # model's own head back in so the strict load below still covers
            # every key but leaves ``fc`` at its initial values.
            fc_weight = model.fc.weight.data.clone()
            fc_bias = model.fc.bias.data.clone()
        remapped_weights['fc.weight'] = fc_weight
        remapped_weights['fc.bias'] = fc_bias
        model.load_state_dict(remapped_weights)
    return model
|