Lots of refactoring and cleanup.

* Move 'test time pool' to Module that can be used by any model, remove from DPN
* Remove ResNext model file and combine with ResNet
* Remove fbresnet200 as it was an old conversion and pretrained performance not worth param count
* Cleanup adaptive avgmax pooling and add back conctat variant
* Factor out checkpoint load fn
pull/1/head
Ross Wightman 5 years ago
parent f2029dfb65
commit 0bc50e84f8

@ -39,8 +39,7 @@ class PrefetchLoader:
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True) next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True) next_target = next_target.cuda(non_blocking=True)
next_input = next_input.float() next_input = next_input.float().sub_(self.mean).div_(self.std)
next_input = next_input.sub_(self.mean).div_(self.std)
if self.random_erasing is not None: if self.random_erasing is not None:
next_input = self.random_erasing(next_input) next_input = self.random_erasing(next_input)
@ -74,6 +73,7 @@ def create_loader(
std=IMAGENET_DEFAULT_STD, std=IMAGENET_DEFAULT_STD,
num_workers=1, num_workers=1,
distributed=False, distributed=False,
crop_pct=None,
): ):
if is_training: if is_training:
@ -87,7 +87,8 @@ def create_loader(
img_size, img_size,
use_prefetcher=use_prefetcher, use_prefetcher=use_prefetcher,
mean=mean, mean=mean,
std=std) std=std,
crop_pct=crop_pct)
dataset.transform = transform dataset.transform = transform

@ -11,7 +11,7 @@ import argparse
import numpy as np import numpy as np
import torch import torch
from models import create_model from models import create_model, load_checkpoint, TestTimePoolHead
from data import Dataset, create_loader, get_model_meanstd from data import Dataset, create_loader, get_model_meanstd
from utils import AverageMeter from utils import AverageMeter
@ -49,21 +49,15 @@ def main():
model = create_model( model = create_model(
args.model, args.model,
num_classes=num_classes, num_classes=num_classes,
pretrained=args.pretrained, pretrained=args.pretrained)
test_time_pool=args.test_time_pool)
print('Model %s created, param count: %d' %
# resume from a checkpoint (args.model, sum([m.numel() for m in model.parameters()])))
if args.checkpoint and os.path.isfile(args.checkpoint):
print("=> loading checkpoint '{}'".format(args.checkpoint)) # load a checkpoint
checkpoint = torch.load(args.checkpoint) if not args.pretrained:
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if not load_checkpoint(model, args.checkpoint):
model.load_state_dict(checkpoint['state_dict']) exit(1)
else:
model.load_state_dict(checkpoint)
print("=> loaded checkpoint '{}'".format(args.checkpoint))
elif not args.pretrained:
print("=> no checkpoint found at '{}'".format(args.checkpoint))
exit(1)
if args.num_gpu > 1: if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()

@ -1,2 +1,3 @@
from .model_factory import create_model from .model_factory import create_model, load_checkpoint
from .test_time_pool import TestTimePoolHead

@ -14,29 +14,70 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1): def adaptive_pool_feat_mult(pool_type='avg'):
if pool_type == 'catavgmax':
return 2
else:
return 1
def adaptive_avgmax_pool2d(x, output_size=1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return 0.5 * (x_avg + x_max)
def adaptive_catavgmax_pool2d(x, output_size=1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return torch.cat((x_avg, x_max), 1)
def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
"""Selectable global pooling function with dynamic input kernel size """Selectable global pooling function with dynamic input kernel size
""" """
if pool_type == 'avgmax': if pool_type == 'avg':
x_avg = F.adaptive_avg_pool2d(x, output_size) x = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size) elif pool_type == 'avgmax':
x = 0.5 * (x_avg + x_max) x = adaptive_avgmax_pool2d(x, output_size)
elif pool_type == 'catavgmax':
x = adaptive_catavgmax_pool2d(x, output_size)
elif pool_type == 'max': elif pool_type == 'max':
x = F.adaptive_max_pool2d(x, output_size) x = F.adaptive_max_pool2d(x, output_size)
else: else:
x = F.adaptive_avg_pool2d(x, output_size) assert False, 'Invalid pool type: %s' % pool_type
return x return x
class AdaptiveAvgMaxPool2d(torch.nn.Module): class AdaptiveAvgMaxPool2d(torch.nn.Module):
def __init__(self, output_size=1):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size
def forward(self, x):
return adaptive_avgmax_pool2d(x, self.output_size)
class AdaptiveCatAvgMaxPool2d(torch.nn.Module):
def __init__(self, output_size=1):
super(AdaptiveCatAvgMaxPool2d, self).__init__()
self.output_size = output_size
def forward(self, x):
return adaptive_catavgmax_pool2d(x, self.output_size)
class SelectAdaptivePool2d(torch.nn.Module):
"""Selectable global pooling layer with dynamic input kernel size """Selectable global pooling layer with dynamic input kernel size
""" """
def __init__(self, output_size=1, pool_type='avg'): def __init__(self, output_size=1, pool_type='avg'):
super(AdaptiveAvgMaxPool2d, self).__init__() super(SelectAdaptivePool2d, self).__init__()
self.output_size = output_size self.output_size = output_size
self.pool_type = pool_type self.pool_type = pool_type
if pool_type == 'avgmax': if pool_type == 'avgmax':
self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)]) self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max': elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size) self.pool = nn.AdaptiveMaxPool2d(output_size)
else: else:
@ -45,11 +86,10 @@ class AdaptiveAvgMaxPool2d(torch.nn.Module):
self.pool = nn.AdaptiveAvgPool2d(output_size) self.pool = nn.AdaptiveAvgPool2d(output_size)
def forward(self, x): def forward(self, x):
if self.pool_type == 'avgmax': return self.pool(x)
x = 0.5 * torch.sum(torch.stack([p(x) for p in self.pool]), 0).squeeze(dim=0)
else: def feat_mult(self):
x = self.pool(x) return adaptive_pool_feat_mult(self.pool_type)
return x
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + ' (' \ return self.__class__.__name__ + ' (' \

@ -83,7 +83,6 @@ def densenet161(pretrained=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
print(kwargs)
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
if pretrained: if pretrained:
state_dict = model_zoo.load_url(model_urls['densenet161']) state_dict = model_zoo.load_url(model_urls['densenet161'])
@ -193,7 +192,7 @@ class DenseNet(nn.Module):
x = self.features(x) x = self.features(x)
x = F.relu(x, inplace=True) x = F.relu(x, inplace=True)
if pool: if pool:
x = adaptive_avgmax_pool2d(x, self.global_pool) x = select_adaptive_pool2d(x, self.global_pool)
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
return x return x

@ -16,7 +16,7 @@ import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo import torch.utils.model_zoo as model_zoo
from collections import OrderedDict from collections import OrderedDict
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d from .adaptive_avgmax_pool import select_adaptive_pool2d
__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107'] __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
@ -41,31 +41,31 @@ model_urls = {
} }
def dpn68(num_classes=1000, pretrained=False, test_time_pool=7): def dpn68(num_classes=1000, pretrained=False):
model = DPN( model = DPN(
small=True, num_init_features=10, k_r=128, groups=32, small=True, num_init_features=10, k_r=128, groups=32,
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
num_classes=num_classes, test_time_pool=test_time_pool) num_classes=num_classes)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn68'])) model.load_state_dict(model_zoo.load_url(model_urls['dpn68']))
return model return model
def dpn68b(num_classes=1000, pretrained=False, test_time_pool=7): def dpn68b(num_classes=1000, pretrained=False):
model = DPN( model = DPN(
small=True, num_init_features=10, k_r=128, groups=32, small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
num_classes=num_classes, test_time_pool=test_time_pool) num_classes=num_classes)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra'])) model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra']))
return model return model
def dpn92(num_classes=1000, pretrained=False, test_time_pool=7, extra=True): def dpn92(num_classes=1000, pretrained=False, extra=True):
model = DPN( model = DPN(
num_init_features=64, k_r=96, groups=32, num_init_features=64, k_r=96, groups=32,
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
num_classes=num_classes, test_time_pool=test_time_pool) num_classes=num_classes)
if pretrained: if pretrained:
if extra: if extra:
model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra'])) model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra']))
@ -74,31 +74,31 @@ def dpn92(num_classes=1000, pretrained=False, test_time_pool=7, extra=True):
return model return model
def dpn98(num_classes=1000, pretrained=False, test_time_pool=7): def dpn98(num_classes=1000, pretrained=False):
model = DPN( model = DPN(
num_init_features=96, k_r=160, groups=40, num_init_features=96, k_r=160, groups=40,
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
num_classes=num_classes, test_time_pool=test_time_pool) num_classes=num_classes)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn98'])) model.load_state_dict(model_zoo.load_url(model_urls['dpn98']))
return model return model
def dpn131(num_classes=1000, pretrained=False, test_time_pool=7): def dpn131(num_classes=1000, pretrained=False):
model = DPN( model = DPN(
num_init_features=128, k_r=160, groups=40, num_init_features=128, k_r=160, groups=40,
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
num_classes=num_classes, test_time_pool=test_time_pool) num_classes=num_classes)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn131'])) model.load_state_dict(model_zoo.load_url(model_urls['dpn131']))
return model return model
def dpn107(num_classes=1000, pretrained=False, test_time_pool=7): def dpn107(num_classes=1000, pretrained=False):
model = DPN( model = DPN(
num_init_features=128, k_r=200, groups=50, num_init_features=128, k_r=200, groups=50,
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
num_classes=num_classes, test_time_pool=test_time_pool) num_classes=num_classes)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra'])) model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra']))
return model return model
@ -212,10 +212,9 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module): class DPN(nn.Module):
def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
num_classes=1000, test_time_pool=0, fc_act=nn.ELU(inplace=True)): num_classes=1000, fc_act=nn.ELU(inplace=True)):
super(DPN, self).__init__() super(DPN, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.test_time_pool = test_time_pool
self.b = b self.b = b
bw_factor = 1 if small else 4 bw_factor = 1 if small else 4
@ -287,20 +286,12 @@ class DPN(nn.Module):
def forward_features(self, x, pool=True): def forward_features(self, x, pool=True):
x = self.features(x) x = self.features(x)
if pool: if pool:
x = adaptive_avgmax_pool2d(x, pool_type='avg') x = select_adaptive_pool2d(x, pool_type='avg')
x = x.view(x.size(0), -1)
return x return x
def forward(self, x): def forward(self, x):
x = self.features(x) x = self.forward_features(x)
if not self.training and self.test_time_pool: out = self.classifier(x)
x = F.avg_pool2d(x, kernel_size=self.test_time_pool, stride=1)
out = self.classifier(x)
# The extra test time pool should be pooling an img_size//32 - 6 size patch
out = adaptive_avgmax_pool2d(out, pool_type='avgmax')
else:
x = adaptive_avgmax_pool2d(x, pool_type='avg')
out = self.classifier(x)
return out.view(out.size(0), -1) return out.view(out.size(0), -1)

File diff suppressed because it is too large Load Diff

@ -305,7 +305,7 @@ class InceptionResnetV2(nn.Module):
x = self.block8(x) x = self.block8(x)
x = self.conv2d_7b(x) x = self.conv2d_7b(x)
if pool: if pool:
x = adaptive_avgmax_pool2d(x, self.global_pool) x = select_adaptive_pool2d(x, self.global_pool)
#x = F.avg_pool2d(x, 8, count_include_pad=False) #x = F.avg_pool2d(x, 8, count_include_pad=False)
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
return x return x

@ -272,7 +272,7 @@ class InceptionV4(nn.Module):
def forward_features(self, x, pool=True): def forward_features(self, x, pool=True):
x = self.features(x) x = self.features(x)
if pool: if pool:
x = adaptive_avgmax_pool2d(x, self.global_pool, count_include_pad=False) x = select_adaptive_pool2d(x, self.global_pool, count_include_pad=False)
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
return x return x

@ -1,15 +1,16 @@
import torch import torch
import os import os
from collections import OrderedDict
from .inception_v4 import inception_v4 from .inception_v4 import inception_v4
from .inception_resnet_v2 import inception_resnet_v2 from .inception_resnet_v2 import inception_resnet_v2
from .densenet import densenet161, densenet121, densenet169, densenet201 from .densenet import densenet161, densenet121, densenet169, densenet201
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
from .fbresnet200 import fbresnet200 resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d
from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107 from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \ from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
from .resnext import resnext50, resnext101, resnext152 #from .resnext import resnext50, resnext101, resnext152
from .xception import xception from .xception import xception
model_config_dict = { model_config_dict = {
@ -57,26 +58,18 @@ def create_model(
checkpoint_path='', checkpoint_path='',
**kwargs): **kwargs):
test_time_pool = kwargs.pop('test_time_pool') if 'test_time_pool' in kwargs else 0
if model_name == 'dpn68': if model_name == 'dpn68':
model = dpn68( model = dpn68(num_classes=num_classes, pretrained=pretrained)
num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
elif model_name == 'dpn68b': elif model_name == 'dpn68b':
model = dpn68b( model = dpn68b(num_classes=num_classes, pretrained=pretrained)
num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
elif model_name == 'dpn92': elif model_name == 'dpn92':
model = dpn92( model = dpn92(num_classes=num_classes, pretrained=pretrained)
num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
elif model_name == 'dpn98': elif model_name == 'dpn98':
model = dpn98( model = dpn98(num_classes=num_classes, pretrained=pretrained)
num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
elif model_name == 'dpn131': elif model_name == 'dpn131':
model = dpn131( model = dpn131(num_classes=num_classes, pretrained=pretrained)
num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
elif model_name == 'dpn107': elif model_name == 'dpn107':
model = dpn107( model = dpn107(num_classes=num_classes, pretrained=pretrained)
num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
elif model_name == 'resnet18': elif model_name == 'resnet18':
model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs) model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnet34': elif model_name == 'resnet34':
@ -99,8 +92,6 @@ def create_model(
model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs) model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'inception_v4': elif model_name == 'inception_v4':
model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs) model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'fbresnet200':
model = fbresnet200(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet18': elif model_name == 'seresnet18':
model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs) model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet34': elif model_name == 'seresnet34':
@ -117,12 +108,14 @@ def create_model(
model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnext101_32x4d': elif model_name == 'seresnext101_32x4d':
model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext50': elif model_name == 'resnext50_32x4d':
model = resnext50(num_classes=num_classes, pretrained=pretrained, **kwargs) model = resnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext101': elif model_name == 'resnext101_32x4d':
model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs) model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext152': elif model_name == 'resnext101_64x4d':
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs) model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext152_32x4d':
model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'xception': elif model_name == 'xception':
model = xception(num_classes=num_classes, pretrained=pretrained) model = xception(num_classes=num_classes, pretrained=pretrained)
else: else:
@ -136,13 +129,22 @@ def create_model(
def load_checkpoint(model, checkpoint_path): def load_checkpoint(model, checkpoint_path):
if checkpoint_path is not None and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
print('Loading checkpoint', checkpoint_path) print("=> Loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict']) new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
return True
else: else:
print("Error: No checkpoint found at %s." % checkpoint_path) print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
return False

@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math import math
import torch.utils.model_zoo as model_zoo import torch.utils.model_zoo as model_zoo
from .adaptive_avgmax_pool import AdaptiveAvgMaxPool2d from .adaptive_avgmax_pool import SelectAdaptivePool2d
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
@ -29,8 +29,13 @@ def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
expansion = 1 expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0): def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, drop_rate=0.0):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock doest not support changing base width'
self.conv1 = conv3x3(inplanes, planes, stride) self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
@ -65,14 +70,18 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module): class Bottleneck(nn.Module):
expansion = 4 expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0): def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, drop_rate=0.0):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes) width = int(math.floor(planes * (base_width / 64)) * cardinality)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False) self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
padding=1, groups=cardinality, bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4) self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = downsample self.downsample = downsample
@ -108,10 +117,13 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module): class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, def __init__(self, block, layers, num_classes=1000,
cardinality=1, base_width=64,
drop_rate=0.0, block_drop_rate=0.0, drop_rate=0.0, block_drop_rate=0.0,
global_pool='avg'): global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.inplanes = 64 self.inplanes = 64
self.cardinality = cardinality
self.base_width = base_width
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.expansion = block.expansion self.expansion = block.expansion
super(ResNet, self).__init__() super(ResNet, self).__init__()
@ -123,31 +135,29 @@ class ResNet(nn.Module):
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate)
self.global_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features, num_classes) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1) nn.init.constant_(m.weight, 1.)
m.bias.data.zero_() nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.): def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.):
downsample = None downsample = None
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential( downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion), nn.BatchNorm2d(planes * block.expansion),
) )
layers = [block(self.inplanes, planes, stride, downsample, drop_rate)] layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width, drop_rate)]
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion
for i in range(1, blocks): for i in range(1, blocks):
layers.append(block(self.inplanes, planes)) layers.append(block(self.inplanes, planes, cardinality=self.cardinality, base_width=self.base_width))
return nn.Sequential(*layers) return nn.Sequential(*layers)
@ -155,11 +165,11 @@ class ResNet(nn.Module):
return self.fc return self.fc
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes self.num_classes = num_classes
del self.fc del self.fc
if num_classes: if num_classes:
self.fc = nn.Linear(512 * self.expansion, num_classes) self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else: else:
self.fc = None self.fc = None
@ -244,4 +254,52 @@ def resnet152(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model return model
def resnext50_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs)
return model
def resnext101_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
return model
def resnext101_64x4d(cardinality=64, base_width=4, pretrained=False, **kwargs):
"""Constructs a ResNeXt101-64x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
return model
def resnext152_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
"""Constructs a ResNeXt152-32x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs)
return model

@ -1,175 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.utils.model_zoo as model_zoo
from models.adaptive_avgmax_pool import AdaptiveAvgMaxPool2d
__all__ = ['ResNeXt', 'resnext50', 'resnext101', 'resnext152']
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class ResNeXtBottleneckC(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=32, base_width=4):
super(ResNeXtBottleneckC, self).__init__()
width = math.floor(planes * (base_width / 64)) * cardinality
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
padding=1, bias=False, groups=cardinality)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNeXt(nn.Module):
def __init__(self, block, layers, num_classes=1000, cardinality=32, base_width=4,
drop_rate=0., global_pool='avg'):
self.num_classes = num_classes
self.inplanes = 64
self.cardinality = cardinality
self.base_width = base_width
self.drop_rate = drop_rate
super(ResNeXt, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = AdaptiveAvgMaxPool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, 1, None, self.cardinality, self.base_width))
return nn.Sequential(*layers)
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.avgpool = AdaptiveAvgMaxPool2d(pool_type=global_pool)
self.num_classes = num_classes
del self.fc
if num_classes:
self.fc = nn.Linear(self.num_features, num_classes)
else:
self.fc = None
def forward_features(self, x, pool=True):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if pool:
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
def forward(self, x):
x = self.forward_features(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
def resnext50(cardinality=32, base_width=4, pretrained=False, **kwargs):
"""Constructs a ResNeXt-50 model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
"""
model = ResNeXt(
ResNeXtBottleneckC, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs)
return model
def resnext101(cardinality=32, base_width=4, pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
"""
model = ResNeXt(
ResNeXtBottleneckC, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
return model
def resnext152(cardinality=32, base_width=4, pretrained=False, **kwargs):
"""Constructs a ResNeXt-152 model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
"""
model = ResNeXt(
ResNeXtBottleneckC, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs)
return model

@ -9,7 +9,7 @@ import math
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils import model_zoo from torch.utils import model_zoo
from models.adaptive_avgmax_pool import AdaptiveAvgMaxPool2d from models.adaptive_avgmax_pool import SelectAdaptivePool2d
__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152', __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
'seresnext50_32x4d', 'seresnext101_32x4d'] 'seresnext50_32x4d', 'seresnext101_32x4d']
@ -307,7 +307,7 @@ class SENet(nn.Module):
downsample_kernel_size=downsample_kernel_size, downsample_kernel_size=downsample_kernel_size,
downsample_padding=downsample_padding downsample_padding=downsample_padding
) )
self.avg_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool) self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.num_features = 512 * block.expansion self.num_features = 512 * block.expansion
self.last_linear = nn.Linear(self.num_features, num_classes) self.last_linear = nn.Linear(self.num_features, num_classes)

@ -0,0 +1,27 @@
from torch import nn
import torch.nn.functional as F
from models.adaptive_avgmax_pool import adaptive_avgmax_pool2d
class TestTimePoolHead(nn.Module):
def __init__(self, base, original_pool=7):
super(TestTimePoolHead, self).__init__()
self.base = base
self.original_pool = original_pool
base_fc = self.base.get_classifier()
if isinstance(base_fc, nn.Conv2d):
self.fc = base_fc
else:
self.fc = nn.Conv2d(
self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
self.base.reset_classifier(0) # delete original fc layer
def forward(self, x):
x = self.base.forward_features(x, pool=False)
x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
x = self.fc(x)
x = adaptive_avgmax_pool2d(x, 1)
return x.view(x.size(0), -1)

@ -9,9 +9,8 @@ import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.nn as nn import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
from collections import OrderedDict
from models import create_model from models import create_model, load_checkpoint, TestTimePoolHead
from data import Dataset, create_loader, get_model_meanstd from data import Dataset, create_loader, get_model_meanstd
@ -41,40 +40,26 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
def main(): def main():
args = parser.parse_args() args = parser.parse_args()
test_time_pool = False
if 'dpn' in args.model and args.img_size > 224 and not args.no_test_pool:
test_time_pool = True
# create model # create model
num_classes = 1000 num_classes = 1000
model = create_model( model = create_model(
args.model, args.model,
num_classes=num_classes, num_classes=num_classes,
pretrained=args.pretrained, pretrained=args.pretrained)
test_time_pool=test_time_pool)
print('Model %s created, param count: %d' % print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()]))) (args.model, sum([m.numel() for m in model.parameters()])))
# optionally resume from a checkpoint # load a checkpoint
if args.checkpoint and os.path.isfile(args.checkpoint): if not args.pretrained:
print("=> loading checkpoint '{}'".format(args.checkpoint)) if not load_checkpoint(model, args.checkpoint):
checkpoint = torch.load(args.checkpoint) exit(1)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict() test_time_pool = False
for k, v in checkpoint['state_dict'].items(): # FIXME make this work for networks with default img size != 224 and default pool k != 7
if k.startswith('module'): if args.img_size > 224 and not args.no_test_pool:
name = k[7:] # remove `module.` model = TestTimePoolHead(model)
else: test_time_pool = True
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> loaded checkpoint '{}'".format(args.checkpoint))
elif not args.pretrained:
print("=> no checkpoint found at '{}'".format(args.checkpoint))
exit(1)
if args.num_gpu > 1: if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@ -94,7 +79,8 @@ def main():
use_prefetcher=True, use_prefetcher=True,
mean=data_mean, mean=data_mean,
std=data_std, std=data_std,
num_workers=args.workers) num_workers=args.workers,
crop_pct=1.0 if test_time_pool else None)
batch_time = AverageMeter() batch_time = AverageMeter()
losses = AverageMeter() losses = AverageMeter()

Loading…
Cancel
Save