Uniform pretrained model handling.

* All models have 'default_cfgs' dict
* load/resume/pretrained helpers factored out
* pretrained load operates on state_dict based on default_cfg
* test all models in validate
* schedule, optim factory factored out
* test time pool wrapper applied based on default_cfg
Ross Wightman
parent 63e677d03b
commit 9c3859fb9c
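The net effect, as a minimal sketch against the APIs this commit introduces (model name and settings are arbitrary examples, not part of the commit):

from models import create_model

# Every factory now attaches a default_cfg describing its pretrained weights, so
# downstream code can derive normalization, input size, etc. from the model itself:
model = create_model('resnet50', num_classes=1000, in_chans=3, pretrained=False)
print(model.default_cfg['input_size'], model.default_cfg['mean'], model.default_cfg['std'])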

@@ -1,4 +1,4 @@
from data.dataset import Dataset
from data.transforms import transforms_imagenet_eval, transforms_imagenet_train, get_model_meanstd
from data.utils import create_loader
from data.transforms import *
from data.loader import create_loader
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy

@@ -94,6 +94,9 @@ def create_loader(
sampler = None
if distributed:
# FIXME note, doing this for validation isn't technically correct
# There currently is no fixed order distributed sampler that corrects
# for padded entries
sampler = tdata.distributed.DistributedSampler(dataset)
loader = tdata.DataLoader(

@@ -1,8 +1,5 @@
from __future__ import absolute_import
#from torchvision.transforms import *
from PIL import Image
import random
import math
import numpy as np

@@ -7,26 +7,57 @@ from data.random_erasing import RandomErasingNumpy
DEFAULT_CROP_PCT = 0.875
IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5]
IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5]
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
def get_mean_and_std(model, args, num_chan=3):
if hasattr(model, 'default_cfg'):
mean = model.default_cfg['mean']
std = model.default_cfg['std']
else:
if args.mean is not None:
mean = tuple(args.mean)
if len(mean) == 1:
mean = tuple(list(mean) * num_chan)
else:
assert len(mean) == num_chan
else:
mean = get_mean_by_model(args.model)
if args.std is not None:
std = tuple(args.std)
if len(std) == 1:
std = tuple(list(std) * num_chan)
else:
assert len(std) == num_chan
else:
std = get_std_by_model(args.model)
return mean, std
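For instance, a model carrying a default_cfg short-circuits the CLI and name-based fallbacks (the stand-in class and args namespace below are illustrative only; get_mean_and_std is from this commit):

import argparse

class _WithCfg:  # stand-in for any created model that carries a default_cfg
    default_cfg = {'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)}

args = argparse.Namespace(mean=None, std=None, model='inception_v4')
mean, std = get_mean_and_std(_WithCfg(), args)
assert mean == (0.5, 0.5, 0.5)  # default_cfg wins over args and name heuristics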
def get_mean_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_MEAN
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
# FIXME replace these mean/std fn with model factory based values from config dict
def get_model_meanstd(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
elif 'ception' in model_name or 'nasnet' in model_name:
return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def get_std_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_STD
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
def get_model_mean(model_name):
def get_mean_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_MEAN
@@ -36,7 +67,7 @@ def get_model_mean(model_name):
return IMAGENET_DEFAULT_MEAN
def get_model_std(model_name):
def get_std_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_STD
@@ -93,8 +124,8 @@ def transforms_imagenet_train(
tfl += [
ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean) * 255,
std=torch.tensor(std) * 255)
mean=torch.tensor(mean),
std=torch.tensor(std))
]
if random_erasing > 0.:
tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
@@ -124,11 +155,5 @@ def transforms_imagenet_eval(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
# tfl += [
# ToTensor(),
# transforms.Normalize(
# mean=torch.tensor(mean) * 255,
# std=torch.tensor(std) * 255)
# ]
return transforms.Compose(tfl)

@@ -11,10 +11,11 @@ import argparse
import numpy as np
import torch
from models import create_model, load_checkpoint, TestTimePoolHead
from data import Dataset, create_loader, get_model_meanstd
from models import create_model, apply_test_time_pool
from data import Dataset, create_loader, get_mean_and_std
from utils import AverageMeter
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR',
@@ -29,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -45,26 +48,24 @@ def main():
args = parser.parse_args()
# create model
num_classes = 1000
model = create_model(
args.model,
num_classes=num_classes,
pretrained=args.pretrained)
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint)
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
# load a checkpoint
if not args.pretrained:
if not load_checkpoint(model, args.checkpoint):
exit(1)
data_mean, data_std = get_mean_and_std(model, args)
model, test_time_pool = apply_test_time_pool(model, args)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
model = model.cuda()
data_mean, data_std = get_model_meanstd(args.model)
loader = create_loader(
Dataset(args.data),
img_size=args.img_size,

@@ -1,3 +1,4 @@
from .model_factory import create_model, load_checkpoint
from .test_time_pool import TestTimePoolHead
from models.model_factory import create_model
from models.helpers import load_checkpoint, resume_checkpoint
from models.test_time_pool import TestTimePoolHead, apply_test_time_pool

@@ -5,19 +5,29 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
from .adaptive_avgmax_pool import *
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import re
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'features.conv0', 'classifier': 'classifier',
}
default_cfgs = {
'densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-241335ed.pth'),
'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-6f0f7f60.pth'),
'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-4c113574.pth'),
'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-17b70270.pth'),
}
@@ -34,59 +44,56 @@ def _filter_pretrained(state_dict):
return state_dict
def densenet121(pretrained=False, **kwargs):
def densenet121(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
default_cfg = default_cfgs['densenet121']
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
state_dict = model_zoo.load_url(model_urls['densenet121'])
model.load_state_dict(_filter_pretrained(state_dict))
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
return model
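The filter_fn hook above lets a model remap checkpoint keys before loading. _filter_pretrained's body is elided by this hunk; any callable of the following shape works (the rename rule here is a made-up example, not the actual function):

import re

def example_filter(state_dict):
    # hypothetical remap: 'norm.1.weight' -> 'norm1.weight' style key fixes
    return {re.sub(r'(norm|relu|conv)\.(\d)', r'\1\2', k): v for k, v in state_dict.items()}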
def densenet169(pretrained=False, **kwargs):
def densenet169(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
default_cfg = default_cfgs['densenet169']
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
state_dict = model_zoo.load_url(model_urls['densenet169'])
model.load_state_dict(_filter_pretrained(state_dict))
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
return model
def densenet201(pretrained=False, **kwargs):
def densenet201(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
default_cfg = default_cfgs['densenet201']
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
state_dict = model_zoo.load_url(model_urls['densenet201'])
model.load_state_dict(_filter_pretrained(state_dict))
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
return model
def densenet161(pretrained=False, **kwargs):
def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
default_cfg = default_cfgs['densenet161']
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
state_dict = model_zoo.load_url(model_urls['densenet161'])
model.load_state_dict(_filter_pretrained(state_dict))
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
return model
@@ -142,14 +149,15 @@ class DenseNet(nn.Module):
num_classes (int) - number of classification classes
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, global_pool='avg'):
num_init_features=64, bn_size=4, drop_rate=0,
num_classes=1000, in_chans=3, global_pool='avg'):
self.global_pool = global_pool
self.num_classes = num_classes
super(DenseNet, self).__init__()
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
@@ -172,7 +180,7 @@ class DenseNet(nn.Module):
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
# Linear layer
self.classifier = torch.nn.Linear(num_features, num_classes)
self.classifier = nn.Linear(num_features, num_classes)
self.num_features = num_features
@@ -184,7 +192,7 @@ class DenseNet(nn.Module):
self.num_classes = num_classes
del self.classifier
if num_classes:
self.classifier = torch.nn.Linear(self.num_features, num_classes)
self.classifier = nn.Linear(self.num_features, num_classes)
else:
self.classifier = None

@@ -13,94 +13,108 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
from .adaptive_avgmax_pool import select_adaptive_pool2d
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d
from data.transforms import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
model_urls = {
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
}
default_cfgs = {
'dpn68':
'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth',
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
'dpn68b_extra':
'http://data.lip6.fr/cadene/pretrainedmodels/'
'dpn68b_extra-84854c156.pth',
'dpn92': '',
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
'dpn92_extra':
'http://data.lip6.fr/cadene/pretrainedmodels/'
'dpn92_extra-b040e4a9b.pth',
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
'dpn98':
'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth',
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
'dpn131':
'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth',
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
'dpn107_extra':
'http://data.lip6.fr/cadene/pretrainedmodels/'
'dpn107_extra-1ac7121e2.pth'
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
}
def dpn68(num_classes=1000, pretrained=False):
def dpn68(num_classes=1000, in_chans=3, pretrained=False):
default_cfg = default_cfgs['dpn68']
model = DPN(
small=True, num_init_features=10, k_r=128, groups=32,
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
num_classes=num_classes)
num_classes=num_classes, in_chans=in_chans)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn68']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def dpn68b(num_classes=1000, pretrained=False):
def dpn68b(num_classes=1000, in_chans=3, pretrained=False):
default_cfg = default_cfgs['dpn68b_extra']
model = DPN(
small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
num_classes=num_classes)
num_classes=num_classes, in_chans=in_chans)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def dpn92(num_classes=1000, pretrained=False, extra=True):
def dpn92(num_classes=1000, in_chans=3, pretrained=False):
default_cfg = default_cfgs['dpn92_extra']
model = DPN(
num_init_features=64, k_r=96, groups=32,
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
num_classes=num_classes)
num_classes=num_classes, in_chans=in_chans)
model.default_cfg = default_cfg
if pretrained:
if extra:
model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra']))
else:
model.load_state_dict(model_zoo.load_url(model_urls['dpn92']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def dpn98(num_classes=1000, pretrained=False):
def dpn98(num_classes=1000, in_chans=3, pretrained=False):
default_cfg = default_cfgs['dpn98']
model = DPN(
num_init_features=96, k_r=160, groups=40,
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
num_classes=num_classes)
num_classes=num_classes, in_chans=in_chans)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn98']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def dpn131(num_classes=1000, pretrained=False):
def dpn131(num_classes=1000, in_chans=3, pretrained=False):
default_cfg = default_cfgs['dpn131']
model = DPN(
num_init_features=128, k_r=160, groups=40,
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
num_classes=num_classes)
num_classes=num_classes, in_chans=in_chans)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn131']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def dpn107(num_classes=1000, pretrained=False):
def dpn107(num_classes=1000, in_chans=3, pretrained=False):
default_cfg = default_cfgs['dpn107_extra']
model = DPN(
num_init_features=128, k_r=200, groups=50,
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
num_classes=num_classes)
num_classes=num_classes, in_chans=in_chans)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
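A quick, illustrative check (not part of the commit; assumes this repo is importable) that the cfg's first_conv path resolves against the module tree:

import torch.nn as nn
from models.dpn import dpn68

model = dpn68(pretrained=False)
first_conv = dict(model.named_modules())[model.default_cfg['first_conv']]
assert isinstance(first_conv, nn.Conv2d)  # 'features.conv1_1.conv' resolves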
@@ -128,11 +142,11 @@ class BnActConv2d(nn.Module):
class InputBlock(nn.Module):
def __init__(self, num_init_features, kernel_size=7,
def __init__(self, num_init_features, kernel_size=7, in_chans=3,
padding=3, activation_fn=nn.ReLU(inplace=True)):
super(InputBlock, self).__init__()
self.conv = nn.Conv2d(
3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
in_chans, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
self.act = activation_fn
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -212,7 +226,7 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
num_classes=1000, fc_act=nn.ELU(inplace=True)):
num_classes=1000, in_chans=3, fc_act=nn.ELU(inplace=True)):
super(DPN, self).__init__()
self.num_classes = num_classes
self.b = b
@@ -222,9 +236,11 @@ class DPN(nn.Module):
# conv1
if small:
blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
blocks['conv1_1'] = InputBlock(
num_init_features, in_chans=in_chans, kernel_size=3, padding=1)
else:
blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)
blocks['conv1_1'] = InputBlock(
num_init_features, in_chans=in_chans, kernel_size=7, padding=3)
# conv2
bw = 64 * bw_factor

@@ -0,0 +1,89 @@
import torch
import torch.utils.model_zoo as model_zoo
import os
from collections import OrderedDict
def load_checkpoint(model, checkpoint_path):
if checkpoint_path and os.path.isfile(checkpoint_path):
print("=> Loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def resume_checkpoint(model, checkpoint_path, start_epoch=None):
optimizer_state = None
if os.path.isfile(checkpoint_path):
print("=> loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
else:
model.load_state_dict(checkpoint)
return optimizer_state, 0 if start_epoch is None else start_epoch
else:
print("=> No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
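A toy round trip of resume_checkpoint, covering the DataParallel 'module.' prefix stripping (the Linear stand-in and temp path are for illustration only):

import torch
from collections import OrderedDict

net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
wrapped = OrderedDict(('module.' + k, v) for k, v in net.state_dict().items())
torch.save({'state_dict': wrapped, 'optimizer': opt.state_dict(), 'epoch': 7}, '/tmp/last.pth.tar')
optimizer_state, start_epoch = resume_checkpoint(net, '/tmp/last.pth.tar')
assert start_epoch == 7 and optimizer_state is not None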
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
state_dict = model_zoo.load_url(default_cfg['url'])
if in_chans == 1:
conv1_name = default_cfg['first_conv']
print('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight']
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
assert False, "Invalid in_chans for pretrained weights"
strict = True
classifier_name = default_cfg['classifier']
if num_classes == 1000 and default_cfg['num_classes'] == 1001:
# special case for imagenet trained models with extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
elif num_classes != default_cfg['num_classes']:
# completely discard fully connected for all other differences between pretrained and created model
del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias']
strict = False
if filter_fn is not None:
state_dict = filter_fn(state_dict)
model.load_state_dict(state_dict, strict=strict)
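The 3-to-1 channel conversion above works because convolution is linear in its input channels: summing the RGB kernels gives the same response on a grayscale image as the original filter applied to that image replicated across channels. A small standalone check:

import torch
import torch.nn.functional as F

w3 = torch.randn(8, 3, 7, 7)        # a pretrained-style first-conv weight
w1 = w3.sum(dim=1, keepdim=True)    # the converted single-channel weight
x = torch.randn(1, 1, 16, 16)       # a grayscale input
assert torch.allclose(F.conv2d(x, w1), F.conv2d(x.repeat(1, 3, 1, 1), w3), atol=1e-4)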

@@ -5,12 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import numpy as np
from .adaptive_avgmax_pool import *
model_urls = {
'imagenet': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth'
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
default_cfgs = {
'inception_resnet_v2': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
}
}
@@ -204,12 +210,14 @@ class Block8(nn.Module):
class InceptionResnetV2(nn.Module):
def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
super(InceptionResnetV2, self).__init__()
self.drop_rate = drop_rate
self.global_pool = global_pool
self.num_classes = num_classes
self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
self.num_features = 1536
self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.maxpool_3a = nn.MaxPool2d(3, stride=2)
@@ -265,29 +273,21 @@ class InceptionResnetV2(nn.Module):
Block8(scale=0.20)
)
self.block8 = Block8(noReLU=True)
self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
self.num_features = 1536
self.last_linear = nn.Linear(1536, num_classes)
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
self.last_linear = nn.Linear(self.num_features, num_classes)
def get_classifier(self):
return self.classif
return self.last_linear
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = global_pool
self.num_classes = num_classes
del self.classif
del self.last_linear
if num_classes:
self.last_linear = torch.nn.Linear(1536, num_classes)
self.last_linear = torch.nn.Linear(self.num_features, num_classes)
else:
self.last_linear = None
def trim_classifier(self, trim=1):
self.num_classes -= trim
new_last_linear = nn.Linear(1536, self.num_classes)
new_last_linear.weight.data = self.last_linear.weight.data[trim:]
new_last_linear.bias.data = self.last_linear.bias.data[trim:]
self.last_linear = new_last_linear
def forward_features(self, x, pool=True):
x = self.conv2d_1a(x)
x = self.conv2d_2a(x)
@@ -318,19 +318,15 @@ class InceptionResnetV2(nn.Module):
return x
def inception_resnet_v2(pretrained=False, num_classes=1000, **kwargs):
def inception_resnet_v2(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
r"""InceptionResnetV2 model architecture from the
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
extra_class = 1 if pretrained else 0
model = InceptionResnetV2(num_classes=num_classes + extra_class, **kwargs)
default_cfg = default_cfgs['inception_resnet_v2']
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
print('Loading pretrained from %s' % model_urls['imagenet'])
model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
model.trim_classifier()
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@@ -5,11 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from .adaptive_avgmax_pool import *
model_urls = {
'imagenet': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth'
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
default_cfgs = {
'inception_v4': {
'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'features.0.conv', 'classifier': 'classif',
}
}
@@ -230,13 +237,15 @@ class Inception_C(nn.Module):
class InceptionV4(nn.Module):
def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
super(InceptionV4, self).__init__()
self.drop_rate = drop_rate
self.global_pool = global_pool
self.num_classes = num_classes
self.num_features = 1536
self.features = nn.Sequential(
BasicConv2d(3, 32, kernel_size=3, stride=2),
BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
BasicConv2d(32, 32, kernel_size=3, stride=1),
BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
Mixed_3a(),
@@ -259,7 +268,7 @@ class InceptionV4(nn.Module):
Inception_C(),
Inception_C(),
)
self.classif = nn.Linear(1536, num_classes)
self.classif = nn.Linear(self.num_features, num_classes)
def get_classifier(self):
return self.classif
@@ -267,12 +276,12 @@ class InceptionV4(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = global_pool
self.num_classes = num_classes
self.classif = nn.Linear(1536, num_classes)
self.classif = nn.Linear(self.num_features, num_classes)
def forward_features(self, x, pool=True):
x = self.features(x)
if pool:
x = select_adaptive_pool2d(x, self.global_pool, count_include_pad=False)
x = select_adaptive_pool2d(x, self.global_pool)
x = x.view(x.size(0), -1)
return x
@@ -284,10 +293,12 @@ class InceptionV4(nn.Module):
return x
def inception_v4(pretrained=False, num_classes=1001, **kwargs):
model = InceptionV4(num_classes=num_classes, **kwargs)
def inception_v4(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['inception_v4']
model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@@ -1,155 +1,34 @@
import torch
import os
from collections import OrderedDict
from .inception_v4 import inception_v4
from .inception_resnet_v2 import inception_resnet_v2
from .densenet import densenet161, densenet121, densenet169, densenet201
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
from models.inception_v4 import inception_v4
from models.inception_resnet_v2 import inception_resnet_v2
from models.densenet import densenet161, densenet121, densenet169, densenet201
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d
from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
from models.dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
#from .resnext import resnext50, resnext101, resnext152
from .xception import xception
from .pnasnet import pnasnet5large
from models.xception import xception
from models.pnasnet import pnasnet5large
model_config_dict = {
'resnet18': {
'model_name': 'resnet18', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'resnet34': {
'model_name': 'resnet34', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'resnet50': {
'model_name': 'resnet50', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'resnet101': {
'model_name': 'resnet101', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'resnet152': {
'model_name': 'resnet152', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'densenet121': {
'model_name': 'densenet121', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'densenet169': {
'model_name': 'densenet169', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'densenet201': {
'model_name': 'densenet201', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'densenet161': {
'model_name': 'densenet161', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
'dpn107': {
'model_name': 'dpn107', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
'dpn92_extra': {
'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
'dpn92': {
'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
'dpn68': {
'model_name': 'dpn68', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
'dpn68b': {
'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
'dpn68b_extra': {
'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
'inception_resnet_v2': {
'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
'xception': {
'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
'pnasnet5large': {
'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'}
}
from models.helpers import load_checkpoint
def create_model(
model_name='resnet50',
pretrained=None,
num_classes=1000,
in_chans=3,
checkpoint_path='',
**kwargs):
if model_name == 'dpn68':
model = dpn68(num_classes=num_classes, pretrained=pretrained)
elif model_name == 'dpn68b':
model = dpn68b(num_classes=num_classes, pretrained=pretrained)
elif model_name == 'dpn92':
model = dpn92(num_classes=num_classes, pretrained=pretrained)
elif model_name == 'dpn98':
model = dpn98(num_classes=num_classes, pretrained=pretrained)
elif model_name == 'dpn131':
model = dpn131(num_classes=num_classes, pretrained=pretrained)
elif model_name == 'dpn107':
model = dpn107(num_classes=num_classes, pretrained=pretrained)
elif model_name == 'resnet18':
model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnet34':
model = resnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnet50':
model = resnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnet101':
model = resnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnet152':
model = resnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'densenet121':
model = densenet121(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'densenet161':
model = densenet161(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'densenet169':
model = densenet169(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'densenet201':
model = densenet201(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'inception_resnet_v2':
model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'inception_v4':
model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet18':
model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet34':
model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet50':
model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet101':
model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnet152':
model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnext26_32x4d':
model = seresnext26_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnext50_32x4d':
model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'seresnext101_32x4d':
model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext50_32x4d':
model = resnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext101_32x4d':
model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext101_64x4d':
model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'resnext152_32x4d':
model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
elif model_name == 'xception':
model = xception(num_classes=num_classes, pretrained=pretrained)
elif model_name == 'pnasnet5large':
model = pnasnet5large(num_classes=num_classes, pretrained=pretrained)
margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)
if model_name in globals():
create_fn = globals()[model_name]
model = create_fn(**margs, **kwargs)
else:
assert False and "Invalid model"
raise RuntimeError('Unknown model (%s)' % model_name)
if checkpoint_path and not pretrained:
load_checkpoint(model, checkpoint_path)
return model
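Because the factory now resolves model_name against the factory functions imported into this module, registering a model is just importing its function; a hedged usage example:

model = create_model('seresnet18', num_classes=10, in_chans=1, pretrained=False)
assert model.default_cfg['classifier'] == 'last_linear'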
def load_checkpoint(model, checkpoint_path):
if checkpoint_path and os.path.isfile(checkpoint_path):
print("=> Loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
return True
else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
return False

@@ -3,29 +3,23 @@ from collections import OrderedDict
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
pretrained_settings = {
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
default_cfgs = {
'pnasnet5large': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
'input_space': 'RGB',
'input_size': [3, 331, 331],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1000
},
'imagenet+background': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
'input_space': 'RGB',
'input_size': [3, 331, 331],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1001
}
}
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
'input_size': (3, 331, 331),
'pool_size': (11, 11),
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),
'crop_pct': 0.8975,
'num_classes': 1001,
'first_conv': 'conv_0.conv',
'classifier': 'last_linear',
},
}
@@ -288,13 +282,14 @@ class Cell(CellBase):
class PNASNet5Large(nn.Module):
def __init__(self, num_classes=1001):
def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg'):
super(PNASNet5Large, self).__init__()
self.num_classes = num_classes
self.num_features = 4320
self.drop_rate = drop_rate
self.conv_0 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)),
('conv', nn.Conv2d(in_chans, 96, kernel_size=3, stride=2, bias=False)),
('bn', nn.BatchNorm2d(96, eps=0.001))
]))
self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54,
@@ -334,18 +329,18 @@ class PNASNet5Large(nn.Module):
self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
in_channels_right=4320, out_channels_right=864)
self.relu = nn.ReLU()
self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
self.dropout = nn.Dropout(0.5)
self.last_linear = nn.Linear(self.num_features, num_classes)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
def get_classifier(self):
return self.last_linear
def reset_classifier(self, num_classes):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
del self.last_linear
if num_classes:
self.last_linear = nn.Linear(self.num_features, num_classes)
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
else:
self.last_linear = None
@@ -367,38 +362,27 @@ class PNASNet5Large(nn.Module):
x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
x = self.relu(x_cell_11)
if pool:
x = self.avg_pool(x)
x = self.global_pool(x)
x = x.view(x.size(0), -1)
return x
def forward(self, input):
x = self.forward_features(input)
x = self.dropout(x)
if self.drop_rate > 0:
x = F.dropout(x, self.drop_rate, training=self.training)
x = self.last_linear(x)
return x
def pnasnet5large(num_classes=1001, pretrained='imagenet'):
def pnasnet5large(num_classes=1000, in_chans=3, pretrained=False):
r"""PNASNet-5 model architecture from the
`"Progressive Neural Architecture Search"
<https://arxiv.org/abs/1712.00559>`_ paper.
"""
default_cfg = default_cfgs['pnasnet5large']
model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans)
model.default_cfg = default_cfg
if pretrained:
settings = pretrained_settings['pnasnet5large']['imagenet']
assert num_classes == settings[
'num_classes'], 'num_classes should be {}, but is {}'.format(
settings['num_classes'], num_classes)
# both 'imagenet'&'imagenet+background' are loaded from same parameters
model = PNASNet5Large(num_classes=1001)
model.load_state_dict(model_zoo.load_url(settings['url']))
#if pretrained == 'imagenet':
new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
new_last_linear.weight.data = model.last_linear.weight.data[1:]
new_last_linear.bias.data = model.last_linear.bias.data[1:]
model.last_linear = new_last_linear
else:
model = PNASNet5Large(num_classes=num_classes)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@@ -6,17 +6,33 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.utils.model_zoo as model_zoo
from .adaptive_avgmax_pool import SelectAdaptivePool2d
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
def _cfg(url=''):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
'first_conv': 'conv1', 'classifier': 'fc',
}
default_cfgs = {
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
'resnext50_32x4d': _cfg(url=''),
'resnext101_32x4d': _cfg(url=''),
'resnext101_64x4d': _cfg(url=''),
'resnext152_32x4d': _cfg(url=''),
}
@@ -116,7 +132,7 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000,
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64,
drop_rate=0.0, block_drop_rate=0.0,
global_pool='avg'):
@@ -127,7 +143,7 @@ class ResNet(nn.Module):
self.drop_rate = drop_rate
self.expansion = block.expansion
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -197,109 +213,108 @@ class ResNet(nn.Module):
return x
def resnet18(pretrained=False, **kwargs):
def resnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
default_cfg = default_cfgs['resnet18']
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnet34(pretrained=False, **kwargs):
def resnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
default_cfg = default_cfgs['resnet34']
model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnet50(pretrained=False, **kwargs):
def resnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
default_cfg = default_cfgs['resnet50']
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnet101(pretrained=False, **kwargs):
def resnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
default_cfg = default_cfgs['resnet101']
model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnet152(pretrained=False, **kwargs):
def resnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
default_cfg = default_cfgs['resnet152']
model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnext50_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
def resnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
default_cfg = default_cfgs['resnext50_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs)
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnext101_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
def resnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
default_cfg = default_cfgs['resnext101_32x4d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnext101_64x4d(cardinality=64, base_width=4, pretrained=False, **kwargs):
def resnext101_64x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt101-64x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
default_cfg = default_cfgs['resnext101_64x4d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def resnext152_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
def resnext152_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt152-32x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
"""
default_cfg = default_cfgs['resnext152_32x4d']
model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs)
Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
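With the cfg-driven loader, a single-channel variant can still reuse the ImageNet weights (illustrative only; downloads the checkpoint):

model = resnet18(num_classes=1000, in_chans=1, pretrained=True)
assert model.conv1.weight.shape[1] == 1  # first conv converted per the cfg's 'conv1'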

@@ -8,21 +8,40 @@ import math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
'seresnext50_32x4d', 'seresnext101_32x4d']
model_urls = {
'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1',
'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
}
default_cfgs = {
'senet154':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
'seresnet18':
_cfg(url=''),
'seresnet34':
_cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
'seresnet50':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'),
'seresnet101':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
'seresnet152':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
'seresnext26_32x4d':
_cfg(url=''),
'seresnext50_32x4d':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
'seresnext101_32x4d':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'),
}
@@ -197,7 +216,7 @@ class SEResNetBlock(nn.Module):
class SENet(nn.Module):
def __init__(self, block, layers, groups, reduction, drop_rate=0.2,
inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
in_chans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
downsample_padding=1, num_classes=1000, global_pool='avg'):
"""
Parameters
@@ -247,7 +266,7 @@ class SENet(nn.Module):
self.num_classes = num_classes
if input_3x3:
layer0_modules = [
('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)),
('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)),
('bn1', nn.BatchNorm2d(64)),
('relu1', nn.ReLU(inplace=True)),
('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
@@ -260,7 +279,7 @@ class SENet(nn.Module):
else:
layer0_modules = [
('conv1', nn.Conv2d(
inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
('bn1', nn.BatchNorm2d(inplanes)),
('relu1', nn.ReLU(inplace=True)),
]
@@ -368,99 +387,107 @@ class SENet(nn.Module):
return x
def _load_pretrained(model, url, inchans=3):
state_dict = model_zoo.load_url(url)
if inchans == 1:
conv1_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
elif inchans != 3:
assert False, "Invalid inchans for pretrained weights"
model.load_state_dict(state_dict)
def senet154(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
num_classes=num_classes, **kwargs)
if pretrained:
_load_pretrained(model, model_urls['senet154'], inchans)
return model
def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet18']
model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['seresnet18'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet34']
model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['seresnet34'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet50']
model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['seresnet50'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet101']
model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['seresnet101'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet152']
model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def senet154(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['senet154']
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['seresnet152'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def seresnext26_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnext26_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnext26_32x4d']
model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['se_resnext26_32x4d'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnext50_32x4d']
model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['seresnext50_32x4d'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnext101_32x4d']
model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs)
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
_load_pretrained(model, model_urls['seresnext101_32x4d'], inchans)
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
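Every factory above now follows the same three-step pattern: construct the SENet, attach its default_cfg, and hand weight loading to the shared load_pretrained helper. That helper is not shown in this diff, so here is a minimal sketch of a default_cfg-driven loader; the use of the 'first_conv' and 'classifier' cfg keys to adapt input depth and head size is an assumption based on the cfg contents, and the actual implementation in models/helpers.py may differ.

import torch.utils.model_zoo as model_zoo

def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3):
    # Fetch weights from the URL recorded in the model's default_cfg.
    state_dict = model_zoo.load_url(default_cfg['url'])
    strict = True
    if in_chans == 1:
        # Collapse the RGB first-conv weights into a single input channel.
        conv1_name = default_cfg['first_conv']
        conv1_weight = state_dict[conv1_name + '.weight']
        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        raise NotImplementedError('Only 1 or 3 input channels handled in this sketch')
    if num_classes != default_cfg['num_classes']:
        # Drop the pretrained classifier so a fresh head can be trained.
        classifier_name = default_cfg['classifier']
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False
    model.load_state_dict(state_dict, strict=strict)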

@ -25,3 +25,12 @@ class TestTimePoolHead(nn.Module):
x = adaptive_avgmax_pool2d(x, 1)
return x.view(x.size(0), -1)
def apply_test_time_pool(model, args):
test_time_pool = False
if args.img_size > model.default_cfg['input_size'][-1] and not args.no_test_pool:
print('Target input size (%d) > pretrained default (%d), using test time pooling' %
(args.img_size, model.default_cfg['input_size'][-1]))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
test_time_pool = True
return model, test_time_pool
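A quick usage sketch for the wrapper above; the args fields mirror the validate.py flags (img_size, no_test_pool), and the values here are hypothetical.

from types import SimpleNamespace

args = SimpleNamespace(img_size=320, no_test_pool=False)    # hypothetical flags
model = create_model('seresnext50_32x4d', pretrained=True)  # any model with a default_cfg
model, test_time_pool = apply_test_time_pool(model, args)
# test_time_pool is True only when img_size exceeds the cfg's input_size;
# the wrapper then reduces the larger feature map with adaptive_avgmax_pool2d.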

@ -26,24 +26,24 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d
__all__ = ['xception']
pretrained_config = {
default_cfgs = {
'xception': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
'input_space': 'RGB',
'input_size': [3, 299, 299],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1000,
'scale': 0.8975
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
'input_size': (3, 299, 299),
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),
'num_classes': 1000,
'crop_pct': 0.8975,
'first_conv': 'conv1',
'classifier': 'fc'
# The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
}
}
}
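The inline note about a 333 resize follows directly from the new crop_pct key; a hedged sketch of how an eval transform can derive it (variable names here are illustrative, not from this diff):

import math

cfg = default_cfgs['xception']
crop_size = cfg['input_size'][-1]                          # 299
scale_size = int(math.floor(crop_size / cfg['crop_pct'] )) # 299 / 0.8975 -> 333
# Resize the short edge to 333, then center-crop to 299x299 for validation.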
@ -120,16 +120,18 @@ class Xception(nn.Module):
https://arxiv.org/pdf/1610.02357.pdf
"""
def __init__(self, num_classes=1000):
def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
""" Constructor
Args:
num_classes: number of label classes
in_chans: number of input channels (default: 3)
drop_rate: dropout rate applied before the classifier
global_pool: global pooling type, e.g. 'avg'
"""
super(Xception, self).__init__()
self.drop_rate = drop_rate
self.global_pool = global_pool
self.num_classes = num_classes
self.num_features = 2048
self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
@ -173,8 +175,9 @@ class Xception(nn.Module):
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = global_pool
del self.fc
if num_classes:
self.fc = nn.Linear(self.num_features, num_classes)
@ -212,24 +215,23 @@ class Xception(nn.Module):
x = self.relu(x)
if pool:
x = F.adaptive_avg_pool2d(x, (1, 1))
x = select_adaptive_pool2d(x, pool_type=self.global_pool)
x = x.view(x.size(0), -1)
return x
def forward(self, input):
x = self.forward_features(input)
if self.drop_rate:
x = F.dropout(x, self.drop_rate, training=self.training)
x = self.fc(x)
return x
def xception(num_classes=1000, pretrained=False):
model = Xception(num_classes=num_classes)
def xception(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['xception']
model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
config = pretrained_config['xception']['imagenet']
assert num_classes == config['num_classes'], \
"num_classes should be {}, but is {}".format(config['num_classes'], num_classes)
model = Xception(num_classes=num_classes)
model.load_state_dict(model_zoo.load_url(config['url']))
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
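With the old hard assert on num_classes gone, the factory can load ImageNet weights into a model with a different head or input depth. A hypothetical fine-tuning call, assuming load_pretrained behaves as in the sketch earlier:

# Pretrained backbone, single-channel input, fresh 10-way classifier;
# load_pretrained adapts conv1 and skips the fc weights.
model = xception(num_classes=10, in_chans=1, pretrained=True)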

@ -1,2 +1,3 @@
from optim.adabound import AdaBound
from optim.nadam import Nadam
from optim.nadam import Nadam
from optim.optim_factory import create_optimizer

@ -0,0 +1,30 @@
from torch import optim as optim
from optim import Nadam, AdaBound
def create_optimizer(args, parameters):
if args.opt.lower() == 'sgd':
optimizer = optim.SGD(
parameters, lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
elif args.opt.lower() == 'adam':
optimizer = optim.Adam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'nadam':
optimizer = Nadam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'adabound':
optimizer = AdaBound(
parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
final_lr=args.lr)
elif args.opt.lower() == 'adadelta':
optimizer = optim.Adadelta(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'rmsprop':
optimizer = optim.RMSprop(
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=args.weight_decay)
else:
raise ValueError('Invalid optimizer: %s' % args.opt)
return optimizer
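The factory keys off the same argparse namespace as train.py; a minimal sketch driving it without the full parser (field names taken from the flags above, values hypothetical):

from types import SimpleNamespace
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in model
args = SimpleNamespace(opt='sgd', lr=0.01, momentum=0.9,
                       weight_decay=1e-4, opt_eps=1e-8)
optimizer = create_optimizer(args, model.parameters())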

@ -1,4 +1,5 @@
from .cosine_lr import CosineLRScheduler
from .plateau_lr import PlateauLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from scheduler.cosine_lr import CosineLRScheduler
from scheduler.plateau_lr import PlateauLRScheduler
from scheduler.step_lr import StepLRScheduler
from scheduler.tanh_lr import TanhLRScheduler
from scheduler.scheduler_factory import create_scheduler

@ -0,0 +1,43 @@
from scheduler.cosine_lr import CosineLRScheduler
from scheduler.plateau_lr import PlateauLRScheduler
from scheduler.tanh_lr import TanhLRScheduler
from scheduler.step_lr import StepLRScheduler
def create_scheduler(args, optimizer):
num_epochs = args.epochs
# FIXME expose cycle params of the scheduler config to arguments
if args.sched == 'cosine':
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
elif args.sched == 'tanh':
lr_scheduler = TanhLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
else:
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
)
return lr_scheduler, num_epochs
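And the matching scheduler side; note the cosine/tanh branches return a padded num_epochs (cycle length + 10), so the training loop should use the returned value rather than args.epochs. A hedged usage sketch with hypothetical values:

from types import SimpleNamespace
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
args = SimpleNamespace(sched='cosine', epochs=100, decay_rate=0.1, decay_epochs=30,
                       warmup_lr=1e-4, warmup_epochs=3)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
for epoch in range(num_epochs):
    # ... train one epoch ...
    lr_scheduler.step(epoch + 1)  # these schedulers are stepped once per epoch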

@ -1,7 +1,6 @@
import argparse
import time
from collections import OrderedDict
from datetime import datetime
try:
@ -12,17 +11,14 @@ except ImportError:
has_apex = False
from data import *
from models import model_factory
from models import create_model, resume_checkpoint
from utils import *
from optim import Nadam, AdaBound
from loss import LabelSmoothingCrossEntropy
import scheduler
from optim import create_optimizer
from scheduler import create_scheduler
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch.distributed as dist
import torchvision.utils
@ -33,6 +29,8 @@ parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet101")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
@ -120,10 +118,13 @@ def main():
r = torch.distributed.get_rank()
if args.distributed:
print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.'
% (args.world_size, r))
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (r, args.world_size))
else:
print('Training with a single process with %d GPUs.' % args.num_gpu)
print('Training with a single process on %d GPUs.' % args.num_gpu)
# FIXME seed handling for multi-process distributed?
torch.manual_seed(args.seed)
output_dir = ''
if args.local_rank == 0:
@ -137,80 +138,21 @@ def main():
str(args.img_size)])
output_dir = get_outdir(output_base, 'train', exp_name)
batch_size = args.batch_size
torch.manual_seed(args.seed)
data_mean, data_std = get_model_meanstd(args.model)
dataset_train = Dataset(os.path.join(args.data, 'train'))
loader_train = create_loader(
dataset_train,
img_size=args.img_size,
batch_size=batch_size,
is_training=True,
use_prefetcher=True,
random_erasing=0.3,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
dataset_eval = Dataset(os.path.join(args.data, 'validation'))
loader_eval = create_loader(
dataset_eval,
img_size=args.img_size,
batch_size=4 * args.batch_size,
is_training=False,
use_prefetcher=True,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
model = model_factory.create_model(
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=1000,
num_classes=args.num_classes,
drop_rate=args.drop,
global_pool=args.gp,
checkpoint_path=args.initial_checkpoint)
data_mean, data_std = get_mean_and_std(model, args)
# optionally resume from a checkpoint
start_epoch = 0 if args.start_epoch is None else args.start_epoch
start_epoch = 0
optimizer_state = None
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
else:
model.load_state_dict(checkpoint)
else:
print("=> no checkpoint found at '{}'".format(args.resume))
return False
if args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
start_epoch, optimizer_state = resume_checkpoint(model, args.resume, args.start_epoch)
if args.num_gpu > 1:
if args.amp:
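The resume_checkpoint call a few lines up replaces the inline resume block deleted from main(). A sketch of what the factored-out helper plausibly looks like, condensed from that removed code; the real version in models/helpers.py may differ, particularly in error handling.

from collections import OrderedDict
import os
import torch

def resume_checkpoint(model, checkpoint_path, start_epoch=None):
    optimizer_state = None
    if not os.path.isfile(checkpoint_path):
        print("=> no checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError(checkpoint_path)
    checkpoint = torch.load(checkpoint_path)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Strip the 'module.' prefix left behind by DataParallel wrapping.
        new_state_dict = OrderedDict(
            (k[7:] if k.startswith('module') else k, v)
            for k, v in checkpoint['state_dict'].items())
        model.load_state_dict(new_state_dict)
        optimizer_state = checkpoint.get('optimizer')
        if start_epoch is None:
            start_epoch = checkpoint['epoch']
        print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
    else:
        model.load_state_dict(checkpoint)
    return start_epoch or 0, optimizer_state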
@ -237,9 +179,55 @@ def main():
model = DDP(model, delay_allreduce=True)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
if start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
print('Scheduled epochs: ', num_epochs)
train_dir = os.path.join(args.data, 'train')
if not os.path.isdir(train_dir):
print('Error: training folder does not exist at: %s' % train_dir)
exit(1)
dataset_train = Dataset(train_dir)
loader_train = create_loader(
dataset_train,
img_size=args.img_size,
batch_size=args.batch_size,
is_training=True,
use_prefetcher=True,
random_erasing=0.3,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir):
print('Error: validation folder does not exist at: %s' % eval_dir)
exit(1)
dataset_eval = Dataset(eval_dir)
loader_eval = create_loader(
dataset_eval,
img_size=args.img_size,
batch_size=4 * args.batch_size,
is_training=False,
use_prefetcher=True,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
if args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
eval_metric = args.eval_metric
saver = None
if output_dir:
@ -429,76 +417,9 @@ def validate(model, loader, loss_fn, args):
return metrics
def create_optimizer(args, parameters):
if args.opt.lower() == 'sgd':
optimizer = optim.SGD(
parameters, lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
elif args.opt.lower() == 'adam':
optimizer = optim.Adam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'nadam':
optimizer = Nadam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'adabound':
optimizer = AdaBound(
parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
final_lr=args.lr)
elif args.opt.lower() == 'adadelta':
optimizer = optim.Adadelta(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'rmsprop':
optimizer = optim.RMSprop(
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=args.weight_decay)
else:
raise ValueError('Invalid optimizer: %s' % args.opt)
return optimizer
def create_scheduler(args, optimizer):
num_epochs = args.epochs
# FIXME expose cycle params of the scheduler config to arguments
if args.sched == 'cosine':
lr_scheduler = scheduler.CosineLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
elif args.sched == 'tanh':
lr_scheduler = scheduler.TanhLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
else:
lr_scheduler = scheduler.StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
)
return lr_scheduler, num_epochs
def reduce_tensor(tensor, n):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.reduce_op.SUM)
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= n
return rt
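The reduce_op to ReduceOp change tracks the renamed torch.distributed API. For context, a hypothetical call site averaging a validation loss across processes (variable names assumed from the surrounding train.py code):

# Inside the validation loop (sketch).
if args.distributed:
    reduced_loss = reduce_tensor(loss.data, args.world_size)
else:
    reduced_loss = loss.data
losses.update(reduced_loss.item(), input.size(0))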

@ -6,13 +6,14 @@ import argparse
import os
import time
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
from models import create_model, load_checkpoint, TestTimePoolHead
from data import Dataset, create_loader, get_model_meanstd
from models import create_model, apply_test_time_pool
from data import Dataset, create_loader, get_mean_and_std
from utils import accuracy, AverageMeter
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
@ -25,6 +26,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number of classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -41,25 +44,19 @@ def main():
args = parser.parse_args()
# create model
num_classes = 1000
model = create_model(
args.model,
num_classes=num_classes,
pretrained=args.pretrained)
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint)
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
# load a checkpoint
if not args.pretrained:
if not load_checkpoint(model, args.checkpoint):
exit(1)
data_mean, data_std = get_mean_and_std(model, args)
test_time_pool = False
# FIXME make this work for networks with default img size != 224 and default pool k != 7
if args.img_size > 224 and not args.no_test_pool:
model = TestTimePoolHead(model)
test_time_pool = True
model, test_time_pool = apply_test_time_pool(model, args)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@ -69,14 +66,11 @@ def main():
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
cudnn.benchmark = True
data_mean, data_std = get_model_meanstd(args.model)
loader = create_loader(
Dataset(args.data),
img_size=args.img_size,
batch_size=args.batch_size,
use_prefetcher=True,
use_prefetcher=False,
mean=data_mean,
std=data_std,
num_workers=args.workers,
@ -111,51 +105,17 @@ def main():
if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))
i, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
top1=top1, top1a=100. - top1.avg, top5=top5, top5a=100. - top5.avg))
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
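Both helpers above move into utils (imported near the top of validate.py). A toy check of the accuracy semantics:

import torch

output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.8, 0.1, 0.1]])  # logits for 2 samples, 3 classes
target = torch.tensor([1, 2])
prec1, = accuracy(output, target, topk=(1,))
print(prec1)  # tensor(50.) - only the first prediction matches its target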
if __name__ == '__main__':
main()
