Cleanup resolve data config fns, add 'model' variant that takes model as first arg, make 'args' arg optional in original fn

pull/1641/head
Ross Wightman 2 years ago
parent bed350f5e5
commit e9f1376cde

@ -1,6 +1,6 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform rand_augment_transform, auto_augment_transform
from .config import resolve_data_config from .config import resolve_data_config, resolve_model_data_config
from .constants import * from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset from .dataset_factory import create_dataset

@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)
def resolve_data_config( def resolve_data_config(
args, args=None,
default_cfg=None, pretrained_cfg=None,
model=None, model=None,
use_test_size=False, use_test_size=False,
verbose=False verbose=False
): ):
new_config = {} assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
default_cfg = default_cfg or {} args = args or {}
if not default_cfg and model is not None and hasattr(model, 'default_cfg'): pretrained_cfg = pretrained_cfg or {}
default_cfg = model.default_cfg if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
pretrained_cfg = model.pretrained_cfg
data_config = {}
# Resolve input/image size # Resolve input/image size
in_chans = 3 in_chans = 3
@ -32,65 +34,94 @@ def resolve_data_config(
assert isinstance(args['img_size'], int) assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size']) input_size = (in_chans, args['img_size'], args['img_size'])
else: else:
if use_test_size and default_cfg.get('test_input_size', None) is not None: if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
input_size = default_cfg['test_input_size'] input_size = pretrained_cfg['test_input_size']
elif default_cfg.get('input_size', None) is not None: elif pretrained_cfg.get('input_size', None) is not None:
input_size = default_cfg['input_size'] input_size = pretrained_cfg['input_size']
new_config['input_size'] = input_size data_config['input_size'] = input_size
# resolve interpolation method # resolve interpolation method
new_config['interpolation'] = 'bicubic' data_config['interpolation'] = 'bicubic'
if args.get('interpolation', None): if args.get('interpolation', None):
new_config['interpolation'] = args['interpolation'] data_config['interpolation'] = args['interpolation']
elif default_cfg.get('interpolation', None): elif pretrained_cfg.get('interpolation', None):
new_config['interpolation'] = default_cfg['interpolation'] data_config['interpolation'] = pretrained_cfg['interpolation']
# resolve dataset + model mean for normalization # resolve dataset + model mean for normalization
new_config['mean'] = IMAGENET_DEFAULT_MEAN data_config['mean'] = IMAGENET_DEFAULT_MEAN
if args.get('mean', None) is not None: if args.get('mean', None) is not None:
mean = tuple(args['mean']) mean = tuple(args['mean'])
if len(mean) == 1: if len(mean) == 1:
mean = tuple(list(mean) * in_chans) mean = tuple(list(mean) * in_chans)
else: else:
assert len(mean) == in_chans assert len(mean) == in_chans
new_config['mean'] = mean data_config['mean'] = mean
elif default_cfg.get('mean', None): elif pretrained_cfg.get('mean', None):
new_config['mean'] = default_cfg['mean'] data_config['mean'] = pretrained_cfg['mean']
# resolve dataset + model std deviation for normalization # resolve dataset + model std deviation for normalization
new_config['std'] = IMAGENET_DEFAULT_STD data_config['std'] = IMAGENET_DEFAULT_STD
if args.get('std', None) is not None: if args.get('std', None) is not None:
std = tuple(args['std']) std = tuple(args['std'])
if len(std) == 1: if len(std) == 1:
std = tuple(list(std) * in_chans) std = tuple(list(std) * in_chans)
else: else:
assert len(std) == in_chans assert len(std) == in_chans
new_config['std'] = std data_config['std'] = std
elif default_cfg.get('std', None): elif pretrained_cfg.get('std', None):
new_config['std'] = default_cfg['std'] data_config['std'] = pretrained_cfg['std']
# resolve default inference crop # resolve default inference crop
crop_pct = DEFAULT_CROP_PCT crop_pct = DEFAULT_CROP_PCT
if args.get('crop_pct', None): if args.get('crop_pct', None):
crop_pct = args['crop_pct'] crop_pct = args['crop_pct']
else: else:
if use_test_size and default_cfg.get('test_crop_pct', None): if use_test_size and pretrained_cfg.get('test_crop_pct', None):
crop_pct = default_cfg['test_crop_pct'] crop_pct = pretrained_cfg['test_crop_pct']
elif default_cfg.get('crop_pct', None): elif pretrained_cfg.get('crop_pct', None):
crop_pct = default_cfg['crop_pct'] crop_pct = pretrained_cfg['crop_pct']
new_config['crop_pct'] = crop_pct data_config['crop_pct'] = crop_pct
# resolve default crop percentage # resolve default crop percentage
crop_mode = DEFAULT_CROP_MODE crop_mode = DEFAULT_CROP_MODE
if args.get('crop_mode', None): if args.get('crop_mode', None):
crop_mode = args['crop_mode'] crop_mode = args['crop_mode']
elif default_cfg.get('crop_mode', None): elif pretrained_cfg.get('crop_mode', None):
crop_mode = default_cfg['crop_mode'] crop_mode = pretrained_cfg['crop_mode']
new_config['crop_mode'] = crop_mode data_config['crop_mode'] = crop_mode
if verbose: if verbose:
_logger.info('Data processing configuration for current model + dataset:') _logger.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items(): for n, v in data_config.items():
_logger.info('\t%s: %s' % (n, str(v))) _logger.info('\t%s: %s' % (n, str(v)))
return new_config return data_config
def resolve_model_data_config(
model,
args=None,
pretrained_cfg=None,
use_test_size=False,
verbose=False,
):
""" Resolve Model Data Config
This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
Args:
model (nn.Module): the model instance
args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
verbose (bool): enable extra logging of resolved values
Returns:
dictionary of config
"""
return resolve_data_config(
args=args,
pretrained_cfg=pretrained_cfg,
model=model,
use_test_size=use_test_size,
verbose=verbose,
)

Loading…
Cancel
Save