Uniform pretrained model handling.

* All models have 'default_cfgs' dict
* load/resume/pretrained helpers factored out
* pretrained load operates on state_dict based on default_cfg
* test all models in validate
* schedule, optim factory factored out
* test time pool wrapper applied based on default_cfg
Ross Wightman
parent 63e677d03b
commit 9c3859fb9c
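
A rough sketch of the flow this commit standardizes, mirroring the updated inference.py further down; `args` stands in for the script's argparse namespace (model, img-size, mean/std overrides):

```python
# sketch only: pretrained handling, normalization stats and test-time pooling all
# key off the per-model default_cfg attached by each entrypoint function
model = create_model('resnet50', num_classes=1000, in_chans=3, pretrained=True)
mean, std = get_mean_and_std(model, args)                   # resolved from model.default_cfg
model, test_time_pool = apply_test_time_pool(model, args)   # wrapped only if default_cfg says so
```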

@@ -1,4 +1,4 @@
from data.dataset import Dataset
-from data.transforms import transforms_imagenet_eval, transforms_imagenet_train, get_model_meanstd
-from data.utils import create_loader
+from data.transforms import *
+from data.loader import create_loader
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy

@@ -94,6 +94,9 @@ def create_loader(
    sampler = None
    if distributed:
+        # FIXME note, doing this for validation isn't technically correct
+        # There currently is no fixed order distributed sampler that corrects
+        # for padded entries
        sampler = tdata.distributed.DistributedSampler(dataset)

    loader = tdata.DataLoader(

@@ -1,8 +1,5 @@
from __future__ import absolute_import
-#from torchvision.transforms import *
-from PIL import Image
import random
import math
import numpy as np

@@ -7,26 +7,57 @@ from data.random_erasing import RandomErasingNumpy
DEFAULT_CROP_PCT = 0.875

-IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
-IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
-IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5]
-IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5]
-IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
-IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
+IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
+IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
+IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
+
+
+def get_mean_and_std(model, args, num_chan=3):
+    if hasattr(model, 'default_cfg'):
+        mean = model.default_cfg['mean']
+        std = model.default_cfg['std']
+    else:
+        if args.mean is not None:
+            mean = tuple(args.mean)
+            if len(mean) == 1:
+                mean = tuple(list(mean) * num_chan)
+            else:
+                assert len(mean) == num_chan
+        else:
+            mean = get_mean_by_model(args.model)
+        if args.std is not None:
+            std = tuple(args.std)
+            if len(std) == 1:
+                std = tuple(list(std) * num_chan)
+            else:
+                assert len(std) == num_chan
+        else:
+            std = get_std_by_model(args.model)
+    return mean, std
+
+
+def get_mean_by_name(name):
+    if name == 'dpn':
+        return IMAGENET_DPN_MEAN
+    elif name == 'inception' or name == 'le':
+        return IMAGENET_INCEPTION_MEAN
+    else:
+        return IMAGENET_DEFAULT_MEAN
+
+
+def get_std_by_name(name):
+    if name == 'dpn':
+        return IMAGENET_DPN_STD
+    elif name == 'inception' or name == 'le':
+        return IMAGENET_INCEPTION_STD
+    else:
+        return IMAGENET_DEFAULT_STD

-# FIXME replace these mean/std fn with model factory based values from config dict
-def get_model_meanstd(model_name):
-    model_name = model_name.lower()
-    if 'dpn' in model_name:
-        return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
-    elif 'ception' in model_name or 'nasnet' in model_name:
-        return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-    else:
-        return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
-def get_model_mean(model_name):
+def get_mean_by_model(model_name):
    model_name = model_name.lower()
    if 'dpn' in model_name:
        return IMAGENET_DPN_STD
@@ -36,7 +67,7 @@ def get_model_mean(model_name):
        return IMAGENET_DEFAULT_MEAN

-def get_model_std(model_name):
+def get_std_by_model(model_name):
    model_name = model_name.lower()
    if 'dpn' in model_name:
        return IMAGENET_DEFAULT_STD
@@ -93,8 +124,8 @@ def transforms_imagenet_train(
    tfl += [
        ToTensor(),
        transforms.Normalize(
-            mean=torch.tensor(mean) * 255,
-            std=torch.tensor(std) * 255)
+            mean=torch.tensor(mean),
+            std=torch.tensor(std))
    ]
    if random_erasing > 0.:
        tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
@@ -124,11 +155,5 @@ def transforms_imagenet_eval(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
    ]
-    # tfl += [
-    #     ToTensor(),
-    #     transforms.Normalize(
-    #          mean=torch.tensor(mean) * 255,
-    #          std=torch.tensor(std) * 255)
-    # ]

    return transforms.Compose(tfl)
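
A quick sketch of how the new `get_mean_and_std` helper above resolves normalization stats; `args` is assumed to be an argparse namespace with `model`, `mean`, and `std` attributes, and `resnet50()` is one of the entrypoints updated later in this commit to attach a `default_cfg`:

```python
# precedence: model.default_cfg -> explicit --mean/--std overrides -> model-name fallback
from argparse import Namespace

args = Namespace(model='resnet50', mean=None, std=None)
model = resnet50()
mean, std = get_mean_and_std(model, args)
# -> (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), taken straight from model.default_cfg
```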

@@ -11,10 +11,11 @@ import argparse
import numpy as np
import torch
-from models import create_model, load_checkpoint, TestTimePoolHead
-from data import Dataset, create_loader, get_model_meanstd
+from models import create_model, apply_test_time_pool
+from data import Dataset, create_loader, get_mean_and_std
from utils import AverageMeter

+torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR',
@@ -29,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int,
                    metavar='N', help='Input image dimension')
+parser.add_argument('--num-classes', type=int, default=1000,
+                    help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -45,26 +48,24 @@ def main():
    args = parser.parse_args()

    # create model
-    num_classes = 1000
    model = create_model(
        args.model,
-        num_classes=num_classes,
-        pretrained=args.pretrained)
+        num_classes=args.num_classes,
+        in_chans=3,
+        pretrained=args.pretrained,
+        checkpoint_path=args.checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

-    # load a checkpoint
-    if not args.pretrained:
-        if not load_checkpoint(model, args.checkpoint):
-            exit(1)
+    data_mean, data_std = get_mean_and_std(model, args)
+    model, test_time_pool = apply_test_time_pool(model, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

-    data_mean, data_std = get_model_meanstd(args.model)
    loader = create_loader(
        Dataset(args.data),
        img_size=args.img_size,

@@ -1,3 +1,4 @@
-from .model_factory import create_model, load_checkpoint
-from .test_time_pool import TestTimePoolHead
+from models.model_factory import create_model
+from models.helpers import load_checkpoint, resume_checkpoint
+from models.test_time_pool import TestTimePoolHead, apply_test_time_pool

@@ -5,19 +5,29 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
-from .adaptive_avgmax_pool import *
+from models.helpers import load_pretrained
+from models.adaptive_avgmax_pool import *
+from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import re

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

-model_urls = {
-    'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',
-    'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',
-    'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',
-    'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',
+def _cfg(url=''):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'features.conv0', 'classifier': 'classifier',
+    }
+
+
+default_cfgs = {
+    'densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-241335ed.pth'),
+    'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-6f0f7f60.pth'),
+    'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-4c113574.pth'),
+    'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-17b70270.pth'),
}
@@ -34,59 +44,56 @@ def _filter_pretrained(state_dict):
    return state_dict

-def densenet121(pretrained=False, **kwargs):
+def densenet121(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
-    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
+    default_cfg = default_cfgs['densenet121']
+    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
+                     num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
    if pretrained:
-        state_dict = model_zoo.load_url(model_urls['densenet121'])
-        model.load_state_dict(_filter_pretrained(state_dict))
+        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model

-def densenet169(pretrained=False, **kwargs):
+def densenet169(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
-    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
+    default_cfg = default_cfgs['densenet169']
+    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
+                     num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
    if pretrained:
-        state_dict = model_zoo.load_url(model_urls['densenet169'])
-        model.load_state_dict(_filter_pretrained(state_dict))
+        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model

-def densenet201(pretrained=False, **kwargs):
+def densenet201(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
-    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
+    default_cfg = default_cfgs['densenet201']
+    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
+                     num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
    if pretrained:
-        state_dict = model_zoo.load_url(model_urls['densenet201'])
-        model.load_state_dict(_filter_pretrained(state_dict))
+        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model

-def densenet161(pretrained=False, **kwargs):
+def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
-    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
+    print(num_classes, in_chans, pretrained)
+    default_cfg = default_cfgs['densenet161']
+    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
+                     num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
    if pretrained:
-        state_dict = model_zoo.load_url(model_urls['densenet161'])
-        model.load_state_dict(_filter_pretrained(state_dict))
+        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model
@@ -142,14 +149,15 @@ class DenseNet(nn.Module):
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
-                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, global_pool='avg'):
+                 num_init_features=64, bn_size=4, drop_rate=0,
+                 num_classes=1000, in_chans=3, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
-            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
+            ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
@@ -172,7 +180,7 @@ class DenseNet(nn.Module):
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
-        self.classifier = torch.nn.Linear(num_features, num_classes)
+        self.classifier = nn.Linear(num_features, num_classes)

        self.num_features = num_features
@@ -184,7 +192,7 @@ class DenseNet(nn.Module):
        self.num_classes = num_classes
        del self.classifier
        if num_classes:
-            self.classifier = nn.Linear(self.num_features, num_classes)
+            self.classifier = nn.Linear(self.num_features, num_classes)
        else:
            self.classifier = None
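
The body of `_filter_pretrained`, passed as `filter_fn` in the densenet entrypoints above, falls outside the hunks shown. For context, a DenseNet checkpoint filter of this kind typically remaps the dotted `norm.1` / `conv.2` style keys of the original torchvision weights to the flattened names used here; the regex below is borrowed from torchvision and shown only as an illustrative sketch:

```python
def _filter_pretrained(state_dict):
    # e.g. 'features.denseblock1.denselayer1.norm.1.weight' -> '...denselayer1.norm1.weight'
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    for key in list(state_dict.keys()):
        match = pattern.match(key)
        if match:
            state_dict[match.group(1) + match.group(2)] = state_dict[key]
            del state_dict[key]
    return state_dict
```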

@@ -13,94 +13,108 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
-from .adaptive_avgmax_pool import select_adaptive_pool2d
+from models.helpers import load_pretrained
+from models.adaptive_avgmax_pool import select_adaptive_pool2d
+from data.transforms import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD

__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']

-model_urls = {
+def _cfg(url=''):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
+        'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
+    }
+
+
+default_cfgs = {
    'dpn68':
-        'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth',
+        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
    'dpn68b_extra':
-        'http://data.lip6.fr/cadene/pretrainedmodels/'
-        'dpn68b_extra-84854c156.pth',
-    'dpn92': '',
+        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
    'dpn92_extra':
-        'http://data.lip6.fr/cadene/pretrainedmodels/'
-        'dpn92_extra-b040e4a9b.pth',
+        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
    'dpn98':
-        'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth',
+        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
    'dpn131':
-        'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth',
+        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
    'dpn107_extra':
-        'http://data.lip6.fr/cadene/pretrainedmodels/'
-        'dpn107_extra-1ac7121e2.pth'
+        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
}


-def dpn68(num_classes=1000, pretrained=False):
+def dpn68(num_classes=1000, in_chans=3, pretrained=False):
+    default_cfg = default_cfgs['dpn68']
    model = DPN(
        small=True, num_init_features=10, k_r=128, groups=32,
        k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
-        num_classes=num_classes)
+        num_classes=num_classes, in_chans=in_chans)
+    model.default_cfg = default_cfg
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['dpn68']))
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


-def dpn68b(num_classes=1000, pretrained=False):
+def dpn68b(num_classes=1000, in_chans=3, pretrained=False):
+    default_cfg = default_cfgs['dpn68b_extra']
    model = DPN(
        small=True, num_init_features=10, k_r=128, groups=32,
        b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
-        num_classes=num_classes)
+        num_classes=num_classes, in_chans=in_chans)
+    model.default_cfg = default_cfg
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra']))
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


-def dpn92(num_classes=1000, pretrained=False, extra=True):
+def dpn92(num_classes=1000, in_chans=3, pretrained=False):
+    default_cfg = default_cfgs['dpn92_extra']
    model = DPN(
        num_init_features=64, k_r=96, groups=32,
        k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
-        num_classes=num_classes)
+        num_classes=num_classes, in_chans=in_chans)
+    model.default_cfg = default_cfg
    if pretrained:
-        if extra:
-            model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra']))
-        else:
-            model.load_state_dict(model_zoo.load_url(model_urls['dpn92']))
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


-def dpn98(num_classes=1000, pretrained=False):
+def dpn98(num_classes=1000, in_chans=3, pretrained=False):
+    default_cfg = default_cfgs['dpn98']
    model = DPN(
        num_init_features=96, k_r=160, groups=40,
        k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
-        num_classes=num_classes)
+        num_classes=num_classes, in_chans=in_chans)
+    model.default_cfg = default_cfg
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['dpn98']))
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


-def dpn131(num_classes=1000, pretrained=False):
+def dpn131(num_classes=1000, in_chans=3, pretrained=False):
+    default_cfg = default_cfgs['dpn131']
    model = DPN(
        num_init_features=128, k_r=160, groups=40,
        k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
-        num_classes=num_classes)
+        num_classes=num_classes, in_chans=in_chans)
+    model.default_cfg = default_cfg
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['dpn131']))
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


-def dpn107(num_classes=1000, pretrained=False):
+def dpn107(num_classes=1000, in_chans=3, pretrained=False):
+    default_cfg = default_cfgs['dpn107_extra']
    model = DPN(
        num_init_features=128, k_r=200, groups=50,
        k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
-        num_classes=num_classes)
+        num_classes=num_classes, in_chans=in_chans)
+    model.default_cfg = default_cfg
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra']))
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
@@ -128,11 +142,11 @@ class BnActConv2d(nn.Module):

class InputBlock(nn.Module):
-    def __init__(self, num_init_features, kernel_size=7,
+    def __init__(self, num_init_features, kernel_size=7, in_chans=3,
                 padding=3, activation_fn=nn.ReLU(inplace=True)):
        super(InputBlock, self).__init__()
        self.conv = nn.Conv2d(
-            3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
+            in_chans, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
        self.act = activation_fn
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -212,7 +226,7 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
    def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
                 b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
-                 num_classes=1000, fc_act=nn.ELU(inplace=True)):
+                 num_classes=1000, in_chans=3, fc_act=nn.ELU(inplace=True)):
        super(DPN, self).__init__()
        self.num_classes = num_classes
        self.b = b
@@ -222,9 +236,11 @@ class DPN(nn.Module):
        # conv1
        if small:
-            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
+            blocks['conv1_1'] = InputBlock(
+                num_init_features, in_chans=in_chans, kernel_size=3, padding=1)
        else:
-            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)
+            blocks['conv1_1'] = InputBlock(
+                num_init_features, in_chans=in_chans, kernel_size=7, padding=3)

        # conv2
        bw = 64 * bw_factor

@@ -0,0 +1,89 @@
+import torch
+import torch.utils.model_zoo as model_zoo
+import os
+from collections import OrderedDict
+
+
+def load_checkpoint(model, checkpoint_path):
+    if checkpoint_path and os.path.isfile(checkpoint_path):
+        print("=> Loading checkpoint '{}'".format(checkpoint_path))
+        checkpoint = torch.load(checkpoint_path)
+        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict'].items():
+                if k.startswith('module'):
+                    name = k[7:]  # remove `module.`
+                else:
+                    name = k
+                new_state_dict[name] = v
+            model.load_state_dict(new_state_dict)
+        else:
+            model.load_state_dict(checkpoint)
+        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+    else:
+        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
+        raise FileNotFoundError()
+
+
+def resume_checkpoint(model, checkpoint_path, start_epoch=None):
+    start_epoch = 0 if start_epoch is None else start_epoch
+    optimizer_state = None
+    if os.path.isfile(checkpoint_path):
+        print("=> loading checkpoint '{}'".format(checkpoint_path))
+        checkpoint = torch.load(checkpoint_path)
+        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict'].items():
+                if k.startswith('module'):
+                    name = k[7:]  # remove `module.`
+                else:
+                    name = k
+                new_state_dict[name] = v
+            model.load_state_dict(new_state_dict)
+            if 'optimizer' in checkpoint:
+                optimizer_state = checkpoint['optimizer']
+            print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
+        else:
+            model.load_state_dict(checkpoint)
+        return optimizer_state, start_epoch
+    else:
+        print("=> No checkpoint found at '{}'".format(checkpoint_path))
+        raise FileNotFoundError()
+
+
+def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
+    state_dict = model_zoo.load_url(default_cfg['url'])
+
+    if in_chans == 1:
+        conv1_name = default_cfg['first_conv']
+        print('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
+        conv1_weight = state_dict[conv1_name + '.weight']
+        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
+    elif in_chans != 3:
+        assert False, "Invalid in_chans for pretrained weights"
+
+    strict = True
+    classifier_name = default_cfg['classifier']
+    if num_classes == 1000 and default_cfg['num_classes'] == 1001:
+        # special case for imagenet trained models with extra background class in pretrained weights
+        classifier_weight = state_dict[classifier_name + '.weight']
+        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
+        classifier_bias = state_dict[classifier_name + '.bias']
+        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
+    elif num_classes != default_cfg['num_classes']:
+        # completely discard fully connected for all other differences between pretrained and created model
+        del state_dict[classifier_name + '.weight']
+        del state_dict[classifier_name + '.bias']
+        strict = False
+
+    if filter_fn is not None:
+        state_dict = filter_fn(state_dict)
+
+    model.load_state_dict(state_dict, strict=strict)
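
A usage sketch for `load_pretrained` through one of the entrypoints defined earlier in this commit; the target shape here is illustrative:

```python
# grayscale, 10-class target: the first conv weights are summed down to one input
# channel and the 1000-class classifier is discarded (strict=False), while every
# other pretrained tensor loads unchanged
model = densenet121(num_classes=10, in_chans=1, pretrained=True)
assert model.default_cfg['classifier'] == 'classifier'
```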

@@ -5,12 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
-import numpy as np
-from .adaptive_avgmax_pool import *
+from models.helpers import load_pretrained
+from models.adaptive_avgmax_pool import *
+from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

-model_urls = {
-    'imagenet': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth'
+default_cfgs = {
+    'inception_resnet_v2': {
+        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
+        'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
+    }
}
@@ -204,12 +210,14 @@ class Block8(nn.Module):

class InceptionResnetV2(nn.Module):
-    def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'):
+    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super(InceptionResnetV2, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
-        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
+        self.num_features = 1536
+        self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
@@ -265,29 +273,21 @@ class InceptionResnetV2(nn.Module):
            Block8(scale=0.20)
        )
        self.block8 = Block8(noReLU=True)
-        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
-        self.num_features = 1536
-        self.last_linear = nn.Linear(1536, num_classes)
+        self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
+        self.last_linear = nn.Linear(self.num_features, num_classes)

    def get_classifier(self):
-        return self.classif
+        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
-        del self.classif
+        del self.last_linear
        if num_classes:
-            self.last_linear = torch.nn.Linear(1536, num_classes)
+            self.last_linear = torch.nn.Linear(self.num_features, num_classes)
        else:
            self.last_linear = None

-    def trim_classifier(self, trim=1):
-        self.num_classes -= trim
-        new_last_linear = nn.Linear(1536, self.num_classes)
-        new_last_linear.weight.data = self.last_linear.weight.data[trim:]
-        new_last_linear.bias.data = self.last_linear.bias.data[trim:]
-        self.last_linear = new_last_linear
-
    def forward_features(self, x, pool=True):
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
@@ -318,19 +318,15 @@ class InceptionResnetV2(nn.Module):
        return x

-def inception_resnet_v2(pretrained=False, num_classes=1000, **kwargs):
+def inception_resnet_v2(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""InceptionResnetV2 model architecture from the
    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
-    Args:
-        pretrained ('string'): If True, returns a model pre-trained on ImageNet
    """
-    extra_class = 1 if pretrained else 0
-    model = InceptionResnetV2(num_classes=num_classes + extra_class, **kwargs)
+    default_cfg = default_cfgs['inception_resnet_v2']
+    model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
    if pretrained:
-        print('Loading pretrained from %s' % model_urls['imagenet'])
-        model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
-        model.trim_classifier()
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model

@@ -5,11 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.model_zoo as model_zoo
-from .adaptive_avgmax_pool import *
+from models.helpers import load_pretrained
+from models.adaptive_avgmax_pool import *
+from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

-model_urls = {
-    'imagenet': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth'
+default_cfgs = {
+    'inception_v4': {
+        'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth',
+        'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'features.0.conv', 'classifier': 'classif',
+    }
}
@@ -230,13 +237,15 @@ class Inception_C(nn.Module):

class InceptionV4(nn.Module):
-    def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'):
+    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super(InceptionV4, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
+        self.num_features = 1536
        self.features = nn.Sequential(
-            BasicConv2d(3, 32, kernel_size=3, stride=2),
+            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
@@ -259,7 +268,7 @@ class InceptionV4(nn.Module):
            Inception_C(),
            Inception_C(),
        )
-        self.classif = nn.Linear(1536, num_classes)
+        self.classif = nn.Linear(self.num_features, num_classes)

    def get_classifier(self):
        return self.classif
@@ -267,12 +276,12 @@ class InceptionV4(nn.Module):
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
-        self.classif = nn.Linear(1536, num_classes)
+        self.classif = nn.Linear(self.num_features, num_classes)

    def forward_features(self, x, pool=True):
        x = self.features(x)
        if pool:
-            x = select_adaptive_pool2d(x, self.global_pool, count_include_pad=False)
+            x = select_adaptive_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x
@@ -284,10 +293,12 @@ class InceptionV4(nn.Module):
        return x

-def inception_v4(pretrained=False, num_classes=1001, **kwargs):
-    model = InceptionV4(num_classes=num_classes, **kwargs)
+def inception_v4(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
+    default_cfg = default_cfgs['inception_v4']
+    model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
    if pretrained:
-        model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model

@@ -1,155 +1,34 @@
-import torch
-import os
-from collections import OrderedDict
-
-from .inception_v4 import inception_v4
-from .inception_resnet_v2 import inception_resnet_v2
-from .densenet import densenet161, densenet121, densenet169, densenet201
-from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
+from models.inception_v4 import inception_v4
+from models.inception_resnet_v2 import inception_resnet_v2
+from models.densenet import densenet161, densenet121, densenet169, densenet201
+from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
    resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d
-from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
-from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
+from models.dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
+from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
    seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
-#from .resnext import resnext50, resnext101, resnext152
-from .xception import xception
-from .pnasnet import pnasnet5large
+from models.xception import xception
+from models.pnasnet import pnasnet5large

-model_config_dict = {
-    'resnet18': {
-        'model_name': 'resnet18', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'resnet34': {
-        'model_name': 'resnet34', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'resnet50': {
-        'model_name': 'resnet50', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'resnet101': {
-        'model_name': 'resnet101', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'resnet152': {
-        'model_name': 'resnet152', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'densenet121': {
-        'model_name': 'densenet121', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'densenet169': {
-        'model_name': 'densenet169', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'densenet201': {
-        'model_name': 'densenet201', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'densenet161': {
-        'model_name': 'densenet161', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
-    'dpn107': {
-        'model_name': 'dpn107', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
-    'dpn92_extra': {
-        'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
-    'dpn92': {
-        'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
-    'dpn68': {
-        'model_name': 'dpn68', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
-    'dpn68b': {
-        'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
-    'dpn68b_extra': {
-        'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
-    'inception_resnet_v2': {
-        'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
-    'xception': {
-        'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
-    'pnasnet5large': {
-        'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'}
-}
+from models.helpers import load_checkpoint


def create_model(
        model_name='resnet50',
        pretrained=None,
        num_classes=1000,
+        in_chans=3,
        checkpoint_path='',
        **kwargs):

-    if model_name == 'dpn68':
-        model = dpn68(num_classes=num_classes, pretrained=pretrained)
-    elif model_name == 'dpn68b':
-        model = dpn68b(num_classes=num_classes, pretrained=pretrained)
-    elif model_name == 'dpn92':
-        model = dpn92(num_classes=num_classes, pretrained=pretrained)
-    elif model_name == 'dpn98':
-        model = dpn98(num_classes=num_classes, pretrained=pretrained)
-    elif model_name == 'dpn131':
-        model = dpn131(num_classes=num_classes, pretrained=pretrained)
-    elif model_name == 'dpn107':
-        model = dpn107(num_classes=num_classes, pretrained=pretrained)
-    elif model_name == 'resnet18':
-        model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnet34':
-        model = resnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnet50':
-        model = resnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnet101':
-        model = resnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnet152':
-        model = resnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'densenet121':
-        model = densenet121(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'densenet161':
-        model = densenet161(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'densenet169':
-        model = densenet169(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'densenet201':
-        model = densenet201(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'inception_resnet_v2':
-        model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'inception_v4':
-        model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnet18':
-        model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnet34':
-        model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnet50':
-        model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnet101':
-        model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnet152':
-        model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnext26_32x4d':
-        model = seresnext26_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnext50_32x4d':
-        model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'seresnext101_32x4d':
-        model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnext50_32x4d':
-        model = resnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnext101_32x4d':
-        model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnext101_64x4d':
-        model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'resnext152_32x4d':
-        model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
-    elif model_name == 'xception':
-        model = xception(num_classes=num_classes, pretrained=pretrained)
-    elif model_name == 'pnasnet5large':
-        model = pnasnet5large(num_classes=num_classes, pretrained=pretrained)
+    margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)
+    if model_name in globals():
+        create_fn = globals()[model_name]
+        model = create_fn(**margs, **kwargs)
    else:
-        assert False and "Invalid model"
+        raise RuntimeError('Unknown model (%s)' % model_name)

    if checkpoint_path and not pretrained:
-        print(checkpoint_path)
        load_checkpoint(model, checkpoint_path)

    return model
-
-
-def load_checkpoint(model, checkpoint_path):
-    if checkpoint_path and os.path.isfile(checkpoint_path):
-        print("=> Loading checkpoint '{}'".format(checkpoint_path))
-        checkpoint = torch.load(checkpoint_path)
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
-                new_state_dict[name] = v
-            model.load_state_dict(new_state_dict)
-        else:
-            model.load_state_dict(checkpoint)
-        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
-        return True
-    else:
-        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
-        return False
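
With the `globals()` lookup, any entrypoint function imported into model_factory's namespace becomes creatable by name, replacing the old if/elif chain. A couple of illustrative calls (the checkpoint path is hypothetical):

```python
# pretrained weights adapted through load_pretrained, driven by the model's default_cfg
model = create_model('seresnext50_32x4d', num_classes=10, pretrained=True)

# no pretrained weights; restore a local checkpoint via models.helpers.load_checkpoint instead
model = create_model('dpn92', checkpoint_path='./output/checkpoint.pth.tar')  # hypothetical path
```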

@@ -3,29 +3,23 @@ from collections import OrderedDict

import torch
import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F

-pretrained_settings = {
-    'pnasnet5large': {
-        'imagenet': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 331, 331],
-            'input_range': [0, 1],
-            'mean': [0.5, 0.5, 0.5],
-            'std': [0.5, 0.5, 0.5],
-            'num_classes': 1000
-        },
-        'imagenet+background': {
-            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
-            'input_space': 'RGB',
-            'input_size': [3, 331, 331],
-            'input_range': [0, 1],
-            'mean': [0.5, 0.5, 0.5],
-            'std': [0.5, 0.5, 0.5],
-            'num_classes': 1001
-        }
-    }
+from models.helpers import load_pretrained
+from models.adaptive_avgmax_pool import SelectAdaptivePool2d
+
+default_cfgs = {
+    'pnasnet5large': {
+        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
+        'input_size': (3, 331, 331),
+        'pool_size': (11, 11),
+        'mean': (0.5, 0.5, 0.5),
+        'std': (0.5, 0.5, 0.5),
+        'crop_pct': 0.8975,
+        'num_classes': 1001,
+        'first_conv': 'conv_0.conv',
+        'classifier': 'last_linear',
+    },
}
@@ -288,13 +282,14 @@ class Cell(CellBase):

class PNASNet5Large(nn.Module):
-    def __init__(self, num_classes=1001):
+    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg'):
        super(PNASNet5Large, self).__init__()
        self.num_classes = num_classes
        self.num_features = 4320
+        self.drop_rate = drop_rate

        self.conv_0 = nn.Sequential(OrderedDict([
-            ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)),
+            ('conv', nn.Conv2d(in_chans, 96, kernel_size=3, stride=2, bias=False)),
            ('bn', nn.BatchNorm2d(96, eps=0.001))
        ]))
        self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54,
@@ -334,18 +329,18 @@ class PNASNet5Large(nn.Module):
        self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
                            in_channels_right=4320, out_channels_right=864)
        self.relu = nn.ReLU()
-        self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
-        self.dropout = nn.Dropout(0.5)
-        self.last_linear = nn.Linear(self.num_features, num_classes)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

    def get_classifier(self):
        return self.last_linear

-    def reset_classifier(self, num_classes):
+    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        del self.last_linear
        if num_classes:
-            self.last_linear = nn.Linear(self.num_features, num_classes)
+            self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
        else:
            self.last_linear = None
@@ -367,38 +362,27 @@ class PNASNet5Large(nn.Module):
        x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
        x = self.relu(x_cell_11)
        if pool:
-            x = self.avg_pool(x)
+            x = self.global_pool(x)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, input):
        x = self.forward_features(input)
-        x = self.dropout(x)
+        if self.drop_rate > 0:
+            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.last_linear(x)
        return x

-def pnasnet5large(num_classes=1001, pretrained='imagenet'):
+def pnasnet5large(num_classes=1000, in_chans=3, pretrained='imagenet'):
    r"""PNASNet-5 model architecture from the
    `"Progressive Neural Architecture Search"
    <https://arxiv.org/abs/1712.00559>`_ paper.
    """
+    default_cfg = default_cfgs['pnasnet5large']
+    model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans)
+    model.default_cfg = default_cfg
    if pretrained:
-        settings = pretrained_settings['pnasnet5large']['imagenet']
-        assert num_classes == settings[
-            'num_classes'], 'num_classes should be {}, but is {}'.format(
-            settings['num_classes'], num_classes)
-
-        # both 'imagenet'&'imagenet+background' are loaded from same parameters
-        model = PNASNet5Large(num_classes=1001)
-        model.load_state_dict(model_zoo.load_url(settings['url']))
-
-        #if pretrained == 'imagenet':
-        new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
-        new_last_linear.weight.data = model.last_linear.weight.data[1:]
-        new_last_linear.bias.data = model.last_linear.bias.data[1:]
-        model.last_linear = new_last_linear
-    else:
-        model = PNASNet5Large(num_classes=num_classes)
+        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model

@ -6,17 +6,33 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math import math
import torch.utils.model_zoo as model_zoo from models.helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
model_urls = { 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', def _cfg(url=''):
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', return {
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
'first_conv': 'conv1', 'classifier': 'fc',
}
default_cfgs = {
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
'resnext50_32x4d': _cfg(url=''),
'resnext101_32x4d': _cfg(url=''),
'resnext101_64x4d': _cfg(url=''),
'resnext152_32x4d': _cfg(url=''),
} }
@ -116,7 +132,7 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module): class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, cardinality=1, base_width=64,
drop_rate=0.0, block_drop_rate=0.0, drop_rate=0.0, block_drop_rate=0.0,
global_pool='avg'): global_pool='avg'):
@ -127,7 +143,7 @@ class ResNet(nn.Module):
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.expansion = block.expansion self.expansion = block.expansion
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -197,109 +213,108 @@ class ResNet(nn.Module):
return x return x
def resnet18(pretrained=False, **kwargs): def resnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-18 model. """Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) default_cfg = default_cfgs['resnet18']
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnet34(pretrained=False, **kwargs): def resnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-34 model. """Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) default_cfg = default_cfgs['resnet34']
model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnet50(pretrained=False, **kwargs): def resnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-50 model. """Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) default_cfg = default_cfgs['resnet50']
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnet101(pretrained=False, **kwargs): def resnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-101 model. """Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) default_cfg = default_cfgs['resnet101']
model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnet152(pretrained=False, **kwargs): def resnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNet-152 model. """Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
""" """
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) default_cfg = default_cfgs['resnet152']
model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnext50_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): def resnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model. """Constructs a ResNeXt50-32x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
""" """
default_cfg = default_cfgs['resnext50_32x4d2']
model = ResNet( model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs) Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnext101_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): def resnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 model. """Constructs a ResNeXt-101 model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
""" """
default_cfg = default_cfgs['resnext101_32x4d']
model = ResNet( model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs) Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnext101_64x4d(cardinality=64, base_width=4, pretrained=False, **kwargs): def resnext101_64x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt101-64x4d model. """Constructs a ResNeXt101-64x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
""" """
default_cfg = default_cfgs['resnext101_32x4d']
model = ResNet( model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs) Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def resnext152_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): def resnext152_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt152-32x4d model. """Constructs a ResNeXt152-32x4d model.
Args:
cardinality (int): Cardinality of the aggregated transform
base_width (int): Base width of the grouped convolution
""" """
default_cfg = default_cfgs['resnext152_32x4d']
model = ResNet( model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs) Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
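For reference, a minimal usage sketch of the reworked factory interface (the module path and the ResNet class accepting in_chans are assumptions based on this diff; shapes follow the 224x224 default_cfg):

import torch
from models.resnet import resnext50_32x4d  # assumed module path within this repo

# num_classes/in_chans now flow through the factory; pretrained=False avoids any download.
model = resnext50_32x4d(num_classes=10, in_chans=1, pretrained=False)
print(model.default_cfg)                  # the cfg dict now travels with the instance
out = model(torch.randn(2, 1, 224, 224))
print(out.shape)                          # torch.Size([2, 10])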

@ -8,21 +8,40 @@ import math
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils import model_zoo
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152', __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
'seresnext50_32x4d', 'seresnext101_32x4d'] 'seresnext50_32x4d', 'seresnext101_32x4d']
model_urls = {
'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', def _cfg(url=''):
'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', return {
'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1', 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', 'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', }
'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
default_cfgs = {
'senet154':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
'seresnet18':
_cfg(url=''),
'seresnet34':
_cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
'seresnet50':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'),
'seresnet101':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
'seresnet152':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
'seresnext50_32x4d':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
'seresnext101_32x4d':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'),
'seresnext26_32x4d':
_cfg(url=''),
} }
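The real `load_pretrained` lives in models/helpers.py and is not part of this hunk; a rough sketch of the behaviour implied by the cfg keys ('url', 'first_conv', 'classifier', 'num_classes') and by the old `_load_pretrained` below might look like this (names and details are assumptions, not the helper's actual code):

from torch.utils import model_zoo

def load_pretrained_sketch(model, default_cfg, num_classes=1000, in_chans=3):
    # Hypothetical cfg-driven loader, sketched for illustration only.
    state_dict = model_zoo.load_url(default_cfg['url'])
    if in_chans == 1:
        # Collapse RGB first-conv weights to one channel, as _load_pretrained did for conv1.
        conv1_name = default_cfg['first_conv']
        conv1_weight = state_dict[conv1_name + '.weight']
        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        raise ValueError('Invalid in_chans for pretrained weights')
    strict = True
    if num_classes != default_cfg['num_classes']:
        # Drop the classifier weights so a freshly initialized head is kept.
        classifier_name = default_cfg['classifier']
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False
    model.load_state_dict(state_dict, strict=strict)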
@ -197,7 +216,7 @@ class SEResNetBlock(nn.Module):
class SENet(nn.Module): class SENet(nn.Module):
def __init__(self, block, layers, groups, reduction, drop_rate=0.2, def __init__(self, block, layers, groups, reduction, drop_rate=0.2,
inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3, in_chans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
downsample_padding=1, num_classes=1000, global_pool='avg'): downsample_padding=1, num_classes=1000, global_pool='avg'):
""" """
Parameters Parameters
@ -247,7 +266,7 @@ class SENet(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
if input_3x3: if input_3x3:
layer0_modules = [ layer0_modules = [
('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)), ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)),
('bn1', nn.BatchNorm2d(64)), ('bn1', nn.BatchNorm2d(64)),
('relu1', nn.ReLU(inplace=True)), ('relu1', nn.ReLU(inplace=True)),
('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
@ -260,7 +279,7 @@ class SENet(nn.Module):
else: else:
layer0_modules = [ layer0_modules = [
('conv1', nn.Conv2d( ('conv1', nn.Conv2d(
inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
('bn1', nn.BatchNorm2d(inplanes)), ('bn1', nn.BatchNorm2d(inplanes)),
('relu1', nn.ReLU(inplace=True)), ('relu1', nn.ReLU(inplace=True)),
] ]
@ -368,99 +387,107 @@ class SENet(nn.Module):
return x return x
def _load_pretrained(model, url, inchans=3): def seresnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
state_dict = model_zoo.load_url(url) default_cfg = default_cfgs['seresnet18']
if inchans == 1:
conv1_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
elif inchans != 3:
assert False, "Invalid inchans for pretrained weights"
model.load_state_dict(state_dict)
def senet154(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
num_classes=num_classes, **kwargs)
if pretrained:
_load_pretrained(model, model_urls['senet154'], inchans)
return model
def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16, model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['seresnet18'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): def seresnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet34']
model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16, model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['seresnet34'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): def seresnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet50']
model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['seresnet50'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): def seresnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet101']
model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['seresnet101'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): def seresnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnet152']
model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def senet154(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['senet154']
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['seresnet152'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def seresnext26_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): def seresnext26_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnext26_32x4d']
model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16, model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['se_resnext26_32x4d'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): def seresnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnext50_32x4d']
model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['seresnext50_32x4d'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model
def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): def seresnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
default_cfg = default_cfgs['seresnext101_32x4d']
model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
inplanes=64, input_3x3=False, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0, downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes, **kwargs) num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
_load_pretrained(model, model_urls['seresnext101_32x4d'], inchans) load_pretrained(model, default_cfg, num_classes, in_chans)
return model return model

@ -25,3 +25,12 @@ class TestTimePoolHead(nn.Module):
x = adaptive_avgmax_pool2d(x, 1) x = adaptive_avgmax_pool2d(x, 1)
return x.view(x.size(0), -1) return x.view(x.size(0), -1)
def apply_test_time_pool(model, args):
test_time_pool = False
if args.img_size > model.default_cfg['input_size'][-1] and not args.no_test_pool:
print('Target input size (%d) > pretrained default (%d), using test time pooling' %
(args.img_size, model.default_cfg['input_size'][-1]))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
test_time_pool = True
return model, test_time_pool
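A usage sketch of the new helper; `create_model` and `apply_test_time_pool` are both exported from `models` per this changeset, and the `args` namespace stands in for validate.py's --img-size/--no-test-pool options:

from argparse import Namespace
from models import create_model, apply_test_time_pool

model = create_model('seresnet50', num_classes=1000, pretrained=False)
args = Namespace(img_size=288, no_test_pool=False)  # illustrative namespace
model, test_time_pool = apply_test_time_pool(model, args)
print(test_time_pool)  # True: 288 exceeds the cfg's native 224 input size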

@ -26,24 +26,24 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d
__all__ = ['xception'] __all__ = ['xception']
pretrained_config = { default_cfgs = {
'xception': { 'xception': {
'imagenet': { 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', 'input_size': (3, 299, 299),
'input_space': 'RGB', 'mean': (0.5, 0.5, 0.5),
'input_size': [3, 299, 299], 'std': (0.5, 0.5, 0.5),
'input_range': [0, 1], 'num_classes': 1000,
'mean': [0.5, 0.5, 0.5], 'crop_pct': 0.8975,
'std': [0.5, 0.5, 0.5], 'first_conv': 'conv1',
'num_classes': 1000, 'classifier': 'fc'
'scale': 0.8975
# The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
}
} }
} }
@ -120,16 +120,18 @@ class Xception(nn.Module):
https://arxiv.org/pdf/1610.02357.pdf https://arxiv.org/pdf/1610.02357.pdf
""" """
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
""" Constructor """ Constructor
Args: Args:
num_classes: number of classes num_classes: number of classes
""" """
super(Xception, self).__init__() super(Xception, self).__init__()
self.drop_rate = drop_rate
self.global_pool = global_pool
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = 2048 self.num_features = 2048
self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False)
self.bn1 = nn.BatchNorm2d(32) self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
@ -173,8 +175,9 @@ class Xception(nn.Module):
def get_classifier(self): def get_classifier(self):
return self.fc return self.fc
def reset_classifier(self, num_classes): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = global_pool
del self.fc del self.fc
if num_classes: if num_classes:
self.fc = nn.Linear(self.num_features, num_classes) self.fc = nn.Linear(self.num_features, num_classes)
@ -212,24 +215,23 @@ class Xception(nn.Module):
x = self.relu(x) x = self.relu(x)
if pool: if pool:
x = F.adaptive_avg_pool2d(x, (1, 1)) x = select_adaptive_pool2d(x, pool_type=self.global_pool)
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
return x return x
def forward(self, input): def forward(self, input):
x = self.forward_features(input) x = self.forward_features(input)
if self.drop_rate:
x = F.dropout(x, self.drop_rate, training=self.training)
x = self.fc(x) x = self.fc(x)
return x return x
def xception(num_classes=1000, pretrained=False): def xception(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
model = Xception(num_classes=num_classes) default_cfg = default_cfgs['xception']
model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained: if pretrained:
config = pretrained_config['xception']['imagenet'] load_pretrained(model, default_cfg, num_classes, in_chans)
assert num_classes == config['num_classes'], \
"num_classes should be {}, but is {}".format(config['num_classes'], num_classes)
model = Xception(num_classes=num_classes)
model.load_state_dict(model_zoo.load_url(config['url']))
return model return model
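A quick smoke test of the reworked Xception constructor under its 299x299 default_cfg (pretrained weights are skipped, so a single input channel is fine; the module path is assumed):

import torch
from models.xception import xception  # assumed module path

model = xception(num_classes=10, in_chans=1, drop_rate=0.2)
out = model(torch.randn(2, 1, 299, 299))
print(out.shape)  # torch.Size([2, 10])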

@ -1,2 +1,3 @@
from optim.adabound import AdaBound from optim.adabound import AdaBound
from optim.nadam import Nadam from optim.nadam import Nadam
from optim.optim_factory import create_optimizer

@ -0,0 +1,30 @@
from torch import optim as optim
from optim import Nadam, AdaBound
def create_optimizer(args, parameters):
if args.opt.lower() == 'sgd':
optimizer = optim.SGD(
parameters, lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
elif args.opt.lower() == 'adam':
optimizer = optim.Adam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'nadam':
optimizer = Nadam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'adabound':
optimizer = AdaBound(
parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
final_lr=args.lr)
elif args.opt.lower() == 'adadelta':
optimizer = optim.Adadelta(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'rmsprop':
optimizer = optim.RMSprop(
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=args.weight_decay)
else:
assert False, "Invalid optimizer"
raise ValueError
return optimizer
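Usage sketch; the namespace mirrors train.py's optimizer options (--opt, --lr, --momentum, --weight-decay, --opt-eps) and any iterable of parameters works:

from argparse import Namespace
import torch.nn as nn
from optim import create_optimizer

model = nn.Linear(10, 2)  # stand-in for a real model
args = Namespace(opt='sgd', lr=0.01, momentum=0.9, weight_decay=1e-4, opt_eps=1e-8)
optimizer = create_optimizer(args, model.parameters())
print(type(optimizer).__name__)  # SGD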

@ -1,4 +1,5 @@
from .cosine_lr import CosineLRScheduler from scheduler.cosine_lr import CosineLRScheduler
from .plateau_lr import PlateauLRScheduler from scheduler.plateau_lr import PlateauLRScheduler
from .step_lr import StepLRScheduler from scheduler.step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler from scheduler.tanh_lr import TanhLRScheduler
from scheduler.scheduler_factory import create_scheduler

@ -0,0 +1,43 @@
from scheduler.cosine_lr import CosineLRScheduler
from scheduler.plateau_lr import PlateauLRScheduler
from scheduler.tanh_lr import TanhLRScheduler
from scheduler.step_lr import StepLRScheduler
def create_scheduler(args, optimizer):
num_epochs = args.epochs
#FIXME expose cycle params of the scheduler config to arguments
if args.sched == 'cosine':
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
elif args.sched == 'tanh':
lr_scheduler = TanhLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
else:
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
)
return lr_scheduler, num_epochs
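Usage sketch; the fields match the scheduler options in train.py, and num_epochs may exceed args.epochs because the cosine/tanh branches add a 10-epoch cooldown to the cycle length:

import torch
from argparse import Namespace
from scheduler import create_scheduler

optimizer = torch.optim.SGD([torch.zeros(2, 2, requires_grad=True)], lr=0.1)
args = Namespace(sched='step', epochs=100, decay_epochs=30, decay_rate=0.1,
                 warmup_lr=1e-4, warmup_epochs=3)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
for epoch in range(num_epochs):
    # ... train and validate one epoch ...
    lr_scheduler.step(epoch)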

@ -1,7 +1,6 @@
import argparse import argparse
import time import time
from collections import OrderedDict
from datetime import datetime from datetime import datetime
try: try:
@ -12,17 +11,14 @@ except ImportError:
has_apex = False has_apex = False
from data import * from data import *
from models import model_factory from models import create_model, resume_checkpoint
from utils import * from utils import *
from optim import Nadam, AdaBound
from loss import LabelSmoothingCrossEntropy from loss import LabelSmoothingCrossEntropy
import scheduler from optim import create_optimizer
from scheduler import create_scheduler
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch.distributed as dist import torch.distributed as dist
import torchvision.utils import torchvision.utils
@ -33,6 +29,8 @@ parser.add_argument('data', metavar='DIR',
help='path to dataset') help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"') help='Name of model to train (default: "countception"')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"') help='Optimizer (default: "sgd"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
@ -120,10 +118,13 @@ def main():
r = torch.distributed.get_rank() r = torch.distributed.get_rank()
if args.distributed: if args.distributed:
print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.' print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.world_size, r)) % (r, args.world_size))
else: else:
print('Training with a single process with %d GPUs.' % args.num_gpu) print('Training with a single process on %d GPUs.' % args.num_gpu)
# FIXME seed handling for multi-process distributed?
torch.manual_seed(args.seed)
output_dir = '' output_dir = ''
if args.local_rank == 0: if args.local_rank == 0:
@ -137,80 +138,21 @@ def main():
str(args.img_size)]) str(args.img_size)])
output_dir = get_outdir(output_base, 'train', exp_name) output_dir = get_outdir(output_base, 'train', exp_name)
batch_size = args.batch_size model = create_model(
torch.manual_seed(args.seed)
data_mean, data_std = get_model_meanstd(args.model)
dataset_train = Dataset(os.path.join(args.data, 'train'))
loader_train = create_loader(
dataset_train,
img_size=args.img_size,
batch_size=batch_size,
is_training=True,
use_prefetcher=True,
random_erasing=0.3,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
dataset_eval = Dataset(os.path.join(args.data, 'validation'))
loader_eval = create_loader(
dataset_eval,
img_size=args.img_size,
batch_size=4 * args.batch_size,
is_training=False,
use_prefetcher=True,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
model = model_factory.create_model(
args.model, args.model,
pretrained=args.pretrained, pretrained=args.pretrained,
num_classes=1000, num_classes=args.num_classes,
drop_rate=args.drop, drop_rate=args.drop,
global_pool=args.gp, global_pool=args.gp,
checkpoint_path=args.initial_checkpoint) checkpoint_path=args.initial_checkpoint)
data_mean, data_std = get_mean_and_std(model, args)
# optionally resume from a checkpoint # optionally resume from a checkpoint
start_epoch = 0 if args.start_epoch is None else args.start_epoch start_epoch = 0
optimizer_state = None optimizer_state = None
if args.resume: if args.resume:
if os.path.isfile(args.resume): start_epoch, optimizer_state = resume_checkpoint(model, args.resume, args.start_epoch)
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
else:
model.load_state_dict(checkpoint)
else:
print("=> no checkpoint found at '{}'".format(args.resume))
return False
if args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
if args.num_gpu > 1: if args.num_gpu > 1:
if args.amp: if args.amp:
@ -237,9 +179,55 @@ def main():
model = DDP(model, delay_allreduce=True) model = DDP(model, delay_allreduce=True)
lr_scheduler, num_epochs = create_scheduler(args, optimizer) lr_scheduler, num_epochs = create_scheduler(args, optimizer)
if start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0: if args.local_rank == 0:
print('Scheduled epochs: ', num_epochs) print('Scheduled epochs: ', num_epochs)
train_dir = os.path.join(args.data, 'train')
if not os.path.exists(train_dir):
print('Error: training folder does not exist at: %s' % train_dir)
exit(1)
dataset_train = Dataset(train_dir)
loader_train = create_loader(
dataset_train,
img_size=args.img_size,
batch_size=args.batch_size,
is_training=True,
use_prefetcher=True,
random_erasing=0.3,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
eval_dir = os.path.join(args.data, 'validation')
if not os.path.isdir(eval_dir):
print('Error: validation folder does not exist at: %s' % eval_dir)
exit(1)
dataset_eval = Dataset(eval_dir)
loader_eval = create_loader(
dataset_eval,
img_size=args.img_size,
batch_size=4 * args.batch_size,
is_training=False,
use_prefetcher=True,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
if args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
eval_metric = args.eval_metric eval_metric = args.eval_metric
saver = None saver = None
if output_dir: if output_dir:
@ -429,76 +417,9 @@ def validate(model, loader, loss_fn, args):
return metrics return metrics
def create_optimizer(args, parameters):
if args.opt.lower() == 'sgd':
optimizer = optim.SGD(
parameters, lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
elif args.opt.lower() == 'adam':
optimizer = optim.Adam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'nadam':
optimizer = Nadam(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'adabound':
optimizer = AdaBound(
parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
final_lr=args.lr)
elif args.opt.lower() == 'adadelta':
optimizer = optim.Adadelta(
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'rmsprop':
optimizer = optim.RMSprop(
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=args.weight_decay)
else:
assert False and "Invalid optimizer"
raise ValueError
return optimizer
def create_scheduler(args, optimizer):
num_epochs = args.epochs
#FIXME expose cycle parms of the scheduler config to arguments
if args.sched == 'cosine':
lr_scheduler = scheduler.CosineLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
elif args.sched == 'tanh':
lr_scheduler = scheduler.TanhLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
num_epochs = lr_scheduler.get_cycle_length() + 10
else:
lr_scheduler = scheduler.StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
)
return lr_scheduler, num_epochs
def reduce_tensor(tensor, n): def reduce_tensor(tensor, n):
rt = tensor.clone() rt = tensor.clone()
dist.all_reduce(rt, op=dist.reduce_op.SUM) dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= n rt /= n
return rt return rt

@ -6,13 +6,14 @@ import argparse
import os import os
import time import time
import torch import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
from models import create_model, load_checkpoint, TestTimePoolHead from models import create_model, apply_test_time_pool
from data import Dataset, create_loader, get_model_meanstd from data import Dataset, create_loader, get_mean_and_std
from utils import accuracy, AverageMeter
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', metavar='DIR',
@ -25,6 +26,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)') metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int, parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension') metavar='N', help='Input image dimension')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int, parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)') metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -41,25 +44,19 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# create model # create model
num_classes = 1000
model = create_model( model = create_model(
args.model, args.model,
num_classes=num_classes, num_classes=args.num_classes,
pretrained=args.pretrained) in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint)
print('Model %s created, param count: %d' % print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()]))) (args.model, sum([m.numel() for m in model.parameters()])))
# load a checkpoint data_mean, data_std = get_mean_and_std(model, args)
if not args.pretrained:
if not load_checkpoint(model, args.checkpoint):
exit(1)
test_time_pool = False model, test_time_pool = apply_test_time_pool(model, args)
# FIXME make this work for networks with default img size != 224 and default pool k != 7
if args.img_size > 224 and not args.no_test_pool:
model = TestTimePoolHead(model)
test_time_pool = True
if args.num_gpu > 1: if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@ -69,14 +66,11 @@ def main():
# define loss function (criterion) and optimizer # define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda() criterion = nn.CrossEntropyLoss().cuda()
cudnn.benchmark = True
data_mean, data_std = get_model_meanstd(args.model)
loader = create_loader( loader = create_loader(
Dataset(args.data), Dataset(args.data),
img_size=args.img_size, img_size=args.img_size,
batch_size=args.batch_size, batch_size=args.batch_size,
use_prefetcher=True, use_prefetcher=False,
mean=data_mean, mean=data_mean,
std=data_std, std=data_std,
num_workers=args.workers, num_workers=args.workers,
@ -111,51 +105,17 @@ def main():
if i % args.print_freq == 0: if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t' print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time, loss=losses, i, len(loader), batch_time=batch_time,
top1=top1, top5=top5)) rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
if __name__ == '__main__': if __name__ == '__main__':
main() main()
