Merge pull request #16 from rwightman/misc-epoch

Weights, arguments, epoch counting, and more
Ross Wightman committed by GitHub
commit 8a05b8d555

@@ -67,6 +67,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling |
|---|---|---|---|---|
| resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic |
| resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic |
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic |
| efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic |
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic |
@@ -86,7 +87,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
#### @ 260x260
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling |
|---|---|---|---|---|
| efficientnet_b2 | 79.668 (20.332) | 94.634 (5.366) | 9.11M | bicubic |
| efficientnet_b2 | 79.760 (20.240) | 94.714 (5.286) | 9.11M | bicubic |
### Ported Weights

@@ -70,7 +70,7 @@ def main():
logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(model, args)
config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, config, args)
if args.num_gpu > 1:
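
For reference, the new calling convention in a minimal, hedged sketch: `resolve_data_config` now takes a plain dict of arguments first (hence `vars(args)` at the call sites) and accepts the model as an optional keyword to supply its `default_cfg`. The import paths below are assumptions for this version of the package; only the calling convention comes from the diff.

```python
# Hedged usage sketch; the timm import paths are assumptions, only the
# resolve_data_config(vars(args), model=model) convention comes from the diff.
import argparse
from timm.models import create_model
from timm.data import resolve_data_config

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='resnet50', type=str)
parser.add_argument('--img-size', type=int, default=None)
args = parser.parse_args([])

model = create_model(args.model)
# vars() converts the argparse Namespace into the dict the resolver expects;
# keys that are absent or None simply fall back to the model's default_cfg.
config = resolve_data_config(vars(args), model=model)
```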

@@ -2,35 +2,43 @@ import logging
from .constants import *
def resolve_data_config(model, args, default_cfg={}, verbose=True):
def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
new_config = {}
default_cfg = default_cfg
if not default_cfg and hasattr(model, 'default_cfg'):
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg
# Resolve input/image size
# FIXME grayscale/chans arg to use different # channels?
in_chans = 3
if 'chans' in args and args['chans'] is not None:
in_chans = args['chans']
input_size = (in_chans, 224, 224)
if args.img_size is not None:
# FIXME support passing img_size as tuple, non-square
assert isinstance(args.img_size, int)
input_size = (in_chans, args.img_size, args.img_size)
if 'input_size' in args and args['input_size'] is not None:
assert isinstance(args['input_size'], (tuple, list))
assert len(args['input_size']) == 3
input_size = tuple(args['input_size'])
in_chans = input_size[0] # input_size overrides in_chans
elif 'img_size' in args and args['img_size'] is not None:
assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size'])
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bilinear'
if args.interpolation:
new_config['interpolation'] = args.interpolation
new_config['interpolation'] = 'bicubic'
if 'interpolation' in args and args['interpolation']:
new_config['interpolation'] = args['interpolation']
elif 'interpolation' in default_cfg:
new_config['interpolation'] = default_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = get_mean_by_model(args.model)
if args.mean is not None:
mean = tuple(args.mean)
new_config['mean'] = IMAGENET_DEFAULT_MEAN
if 'model' in args:
new_config['mean'] = get_mean_by_model(args['model'])
if 'mean' in args and args['mean'] is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
@@ -40,9 +48,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
new_config['mean'] = default_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = get_std_by_model(args.model)
if args.std is not None:
std = tuple(args.std)
new_config['std'] = IMAGENET_DEFAULT_STD
if 'model' in args:
new_config['std'] = get_std_by_model(args['model'])
if 'std' in args and args['std'] is not None:
std = tuple(args['std'])
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
@@ -53,7 +63,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
# resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT
if 'crop_pct' in default_cfg:
if 'crop_pct' in args and args['crop_pct'] is not None:
new_config['crop_pct'] = args['crop_pct']
elif 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']
if verbose:
@@ -64,29 +76,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
return new_config
def get_mean_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_MEAN
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
def get_std_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_STD
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
def get_mean_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_STD
elif 'ception' in model_name or 'nasnet' in model_name:
elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
@@ -96,7 +90,7 @@ def get_std_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DEFAULT_STD
elif 'ception' in model_name or 'nasnet' in model_name:
elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
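
The substring guard added above fixes a real trap: `'mnasnet'` contains `'nasnet'`, so MnasNet models were silently picking up Inception normalization constants. A self-contained restatement of the corrected predicate (the helper name is invented for this sketch):

```python
# Self-contained illustration of the corrected check; only the predicate
# mirrors the code above, the function name is invented.
def uses_inception_stats(model_name: str) -> bool:
    name = model_name.lower()
    return 'ception' in name or ('nasnet' in name and 'mnasnet' not in name)

assert uses_inception_stats('inception_v3')
assert uses_inception_stats('nasnetalarge')
assert not uses_inception_stats('mnasnet_100')  # previously misclassified
```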

@@ -86,6 +86,7 @@ def create_loader(
use_prefetcher=True,
rand_erase_prob=0.,
rand_erase_mode='const',
color_jitter=0.4,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
@@ -107,6 +108,7 @@ def create_loader(
if is_training:
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,

@@ -6,12 +6,13 @@ import torch
def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
# paths, flip the order so normal is run on CPU if this becomes a problem
# ie torch.empty(patch_size, dtype=dtype).normal_().to(device=device)
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
# will revert back to doing normal_() on GPU when it's in next release
if per_pixel:
return torch.empty(
patch_size, dtype=dtype, device=device).normal_()
patch_size, dtype=dtype).normal_().to(device=device)
elif rand_color:
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
else:
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
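
The pattern in isolation: draw the Gaussian noise into a CPU tensor, then move it to the target device, sidestepping the CUDA `normal_()` crash tracked in the linked issue. A runnable sketch, with the device choice hedged so it runs without a GPU:

```python
# Runnable sketch of the CPU-then-move workaround used above.
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
patch_size = (3, 16, 16)
# normal_() runs on a CPU tensor; only the finished tensor is moved to device
noise = torch.empty(patch_size, dtype=torch.float32).normal_().to(device=device)
print(noise.shape, noise.device)
```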

@@ -156,7 +156,7 @@ class RandomResizedCropAndInterpolation(object):
def transforms_imagenet_train(
img_size=224,
scale=(0.08, 1.0),
color_jitter=(0.4, 0.4, 0.4),
color_jitter=0.4,
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
@@ -164,6 +164,13 @@ def transforms_imagenet_train(
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if specifying brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
tfl = [
RandomResizedCropAndInterpolation(

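In short, `transforms_imagenet_train` now accepts either a scalar or a 3/4-tuple for color jitter. A minimal restatement of the normalization step (function name invented for the sketch):

```python
# Minimal restatement of the scalar-or-tuple handling above.
def normalize_color_jitter(color_jitter):
    if isinstance(color_jitter, (list, tuple)):
        # 3 values cover brightness/contrast/saturation; a 4th adds hue
        assert len(color_jitter) in (3, 4)
        return tuple(color_jitter)
    # a scalar is duplicated for brightness, contrast, saturation (no hue)
    return (float(color_jitter),) * 3

assert normalize_color_jitter(0.4) == (0.4, 0.4, 0.4)
assert normalize_color_jitter([0.4, 0.3, 0.2, 0.1]) == (0.4, 0.3, 0.2, 0.1)
```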
@@ -35,9 +35,9 @@ def _cfg(url=''):
default_cfgs = {
'dpn68': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
'dpn68b_extra': _cfg(
'dpn68b': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
'dpn92_extra': _cfg(
'dpn92': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
'dpn98': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
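
The `_extra` suffix is dropped from the DPN entrypoint names while the weight files keep their original filenames. A hedged smoke test, assuming the package-level factory is exposed as below:

```python
# Hedged sketch; the top-level create_model import is an assumption.
import timm

model = timm.create_model('dpn68b', pretrained=False)  # was 'dpn68b_extra'
model = timm.create_model('dpn92', pretrained=False)   # was 'dpn92_extra'
```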

@@ -84,7 +84,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882),
'efficientnet_b2': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-d4105846.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890),
'efficientnet_b3': _cfg(
url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),

@@ -50,11 +50,9 @@ default_cfgs = {
'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
'gluon_resnext152_32x4d': _cfg(url=''),
'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
'gluon_seresnext152_32x4d': _cfg(url=''),
'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'),
}
@@ -617,20 +615,6 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa
return model
@register_model
def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt152-32x4d model.
"""
default_cfg = default_cfgs['gluon_resnext152_32x4d']
model = GluonResNet(
BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SEResNeXt50-32x4d model.
@@ -673,20 +657,6 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
return model
@register_model
def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a SEResNeXt152-32x4d model.
"""
default_cfg = default_cfgs['gluon_seresnext152_32x4d']
model = GluonResNet(
BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, use_se=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
#if pretrained:
# load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs an SENet-154 model.

@@ -28,8 +28,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
raise FileNotFoundError()
def resume_checkpoint(model, checkpoint_path, start_epoch=None):
def resume_checkpoint(model, checkpoint_path):
optimizer_state = None
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
@@ -40,13 +41,15 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
model.load_state_dict(checkpoint)
start_epoch = 0 if start_epoch is None else start_epoch
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
return optimizer_state, start_epoch
return optimizer_state, resume_epoch
else:
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
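
The epoch bookkeeping above is the subtle part: checkpoints written before this change (no `version` field, or `version` 1) stored `epoch + 1` at save time, while version-2 checkpoints store the completed epoch and get the `+1` applied at resume. A runnable distillation of that rule (function name invented):

```python
# Runnable distillation of the resume-epoch rule; the function name is invented.
def resolve_resume_epoch(checkpoint):
    resume_epoch = None
    if 'epoch' in checkpoint:
        resume_epoch = checkpoint['epoch']
        if checkpoint.get('version', 1) > 1:
            # version >= 2 stores the completed epoch; resume at the next one
            resume_epoch += 1
    return resume_epoch

assert resolve_resume_epoch({'epoch': 10, 'version': 2}) == 11
assert resolve_resume_epoch({'epoch': 11}) == 11  # old format, pre-incremented
assert resolve_resume_epoch({}) is None
```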

@@ -5,22 +5,36 @@ from collections import defaultdict
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
_module_to_models = defaultdict(set)
_model_to_module = {}
_model_entrypoints = {}
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
def register_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
# add model to __all__ in module
model_name = fn.__name__
if hasattr(mod, '__all__'):
mod.__all__.append(fn.__name__)
mod.__all__.append(model_name)
else:
mod.__all__ = [fn.__name__]
_model_entrypoints[fn.__name__] = fn
_model_to_module[fn.__name__] = module_name
_module_to_models[module_name].add(fn.__name__)
mod.__all__ = [model_name]
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
has_pretrained = False # check if model has a pretrained url to allow filtering on this
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn
@@ -28,7 +42,7 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module=''):
def list_models(filter='', module='', pretrained=False):
""" Return list of available model names, sorted alphabetically
Args:
@@ -45,6 +59,8 @@ def list_models(filter='', module=''):
models = _model_entrypoints.keys()
if filter:
models = fnmatch.filter(models, filter)
if pretrained:
models = _model_has_pretrained.intersection(models)
return list(sorted(models, key=_natural_key))
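
With the registry tracking which entrypoints actually have a weight URL, `list_models` can filter on it. A hedged usage example (top-level exposure of `list_models` is assumed):

```python
# Hedged usage; the timm top-level import is an assumption.
import timm

all_resnext = timm.list_models('resnext*')
pretrained_resnext = timm.list_models('resnext*', pretrained=True)
# the second list drops entries whose default_cfg url is empty, e.g. the
# resnext101_32x4d / resnext101_64x4d / resnext152_32x4d placeholders below
```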

@@ -33,14 +33,22 @@ default_cfgs = {
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet34': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
'resnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
interpolation='bicubic'),
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1',
interpolation='bicubic'),
'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'),
'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
'resnext50_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth',
interpolation='bicubic'),
'resnext101_32x4d': _cfg(url=''),
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
'resnext101_64x4d': _cfg(url=''),
'resnext152_32x4d': _cfg(url=''),
'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'),
'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
@@ -285,6 +293,61 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-34 model with original Torchvision weights.
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['tv_resnet34']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model
def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model with original Torchvision weights.
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['tv_resnet50']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model
def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Wide ResNet-50-2 model.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
"""
model = ResNet(
Bottleneck, [3, 4, 6, 3], base_width=128,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['wide_resnet50_2']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model
def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Wide ResNet-101-2 model.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same.
"""
model = ResNet(
Bottleneck, [3, 4, 23, 3], base_width=128,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfgs['wide_resnet101_2']
if pretrained:
load_pretrained(model, model.default_cfg, num_classes, in_chans)
return model
@register_model
def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt50-32x4d model.
@@ -301,7 +364,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
@register_model
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 model.
"""Constructs a ResNeXt-101 32x4d model.
"""
default_cfg = default_cfgs['resnext101_32x4d']
model = ResNet(
@@ -313,6 +376,20 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x8d model.
"""
default_cfg = default_cfgs['resnext101_32x8d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt101-64x4d model.
@@ -328,12 +405,12 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
@register_model
def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt152-32x4d model.
def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt50-32x4d model with original Torchvision weights.
"""
default_cfg = default_cfgs['resnext152_32x4d']
default_cfg = default_cfgs['tv_resnext50_32x4d']
model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4,
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
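
Several of the entrypoints added above (`tv_*`, `wide_resnet*`, `resnext101_32x8d`) mirror the torchvision definitions under distinct names, freeing the plain names for the repo's own weights. A hedged smoke test with `pretrained=False` so nothing is downloaded:

```python
# Hedged smoke test; top-level create_model exposure is an assumption.
import timm

for name in ('tv_resnet50', 'wide_resnet50_2',
             'resnext101_32x8d', 'tv_resnext50_32x4d'):
    m = timm.create_model(name, pretrained=False)
    print(name, sum(p.numel() for p in m.parameters()))
```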

@@ -56,7 +56,7 @@ class Scheduler:
def step(self, epoch: int, metric: float = None) -> None:
self.metric = metric
values = self.get_epoch_values(epoch + 1) # +1 to calculate for next epoch
values = self.get_epoch_values(epoch)
if values is not None:
self.update_groups(values)
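
The off-by-one now lives at the call site instead of inside `Scheduler.step`: the scheduler computes values for exactly the epoch it is given, and the training loop (later in this diff) passes `epoch + 1` after validation. A toy illustration of the convention, with invented decay numbers:

```python
# Toy illustration only; the step-decay numbers are invented.
class StepSchedulerSketch:
    def get_epoch_values(self, epoch):
        return 0.1 * (0.5 ** (epoch // 30))  # lr for the epoch requested

    def step(self, epoch, metric=None):
        return self.get_epoch_values(epoch)  # no hidden +1 anymore

sched = StepSchedulerSketch()
lr_next = sched.step(30 + 1)  # the caller decides what "next epoch" means
assert lr_next == 0.1 * 0.5
```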

@@ -83,7 +83,8 @@ class CheckpointSaver:
'arch': args.model,
'state_dict': get_state_dict(model),
'optimizer': optimizer.state_dict(),
'args': args
'args': args,
'version': 2, # version < 2 increments epoch before save
}
if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema)

@@ -1 +1 @@
__version__ = '0.1.2'
__version__ = '0.1.4'

@@ -27,22 +27,21 @@ import torchvision.utils
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description='Training')
# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet101")')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
help='initial input batch size for training (default: 0)')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.)')
# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
help='number of epochs to train (default: 200)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
@@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.)')
# Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
help='Random erase mode (default: "const")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='path to init checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
@@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "prec1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
@@ -174,13 +181,13 @@ def main():
logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# optionally resume from a checkpoint
start_epoch = 0
optimizer_state = None
resume_epoch = None
if args.resume:
optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)
optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)
if args.num_gpu > 1:
if args.amp:
@@ -232,8 +239,15 @@ def main():
# NOTE: EMA model does not need to be wrapped by DDP
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
logging.info('Scheduled epochs: {}'.format(num_epochs))
@@ -255,6 +269,7 @@ def main():
use_prefetcher=args.prefetcher,
rand_erase_prob=args.reprob,
rand_erase_mode=args.remode,
color_jitter=args.color_jitter,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
@@ -327,7 +342,8 @@ def main():
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
lr_scheduler.step(epoch, eval_metrics[eval_metric])
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
@@ -338,9 +354,7 @@ def main():
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(
model, optimizer, args,
epoch=epoch + 1,
model_ema=model_ema,
metric=save_metric)
epoch=epoch, model_ema=model_ema, metric=save_metric)
except KeyboardInterrupt:
pass
@@ -433,9 +447,8 @@ def train_epoch(
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
save_epoch = epoch + 1 if last_batch else epoch
saver.save_recovery(
model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

@@ -71,7 +71,7 @@ def validate(args):
param_count = sum([m.numel() for m in model.parameters()])
logging.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(model, args)
data_config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, data_config, args)
if args.num_gpu > 1:
