You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/data/config.py

79 lines
2.8 KiB

import logging
from .constants import *
# Logger named after this module; resolve_data_config reports through it when verbose=True.
_logger = logging.getLogger(name=__name__)
def _first_set(key, *sources):
    """Return the first non-None value stored under ``key`` in ``sources``.

    Each source only needs to support ``in`` and ``[]`` (e.g. a dict).
    Returns None when no source holds a non-None value for ``key``.
    """
    for source in sources:
        if key in source and source[key] is not None:
            return source[key]
    return None


def _resolve_norm_stat(key, args, default_cfg, in_chans):
    """Resolve a per-channel normalization statistic ('mean' or 'std').

    A value in ``args`` takes precedence: a single-element value is broadcast
    to ``in_chans`` channels, any other length must equal ``in_chans``.
    Otherwise the ``default_cfg`` value is returned unchanged, or None when
    neither source provides one (the caller supplies the final fallback).
    """
    value = _first_set(key, args)
    if value is None:
        return _first_set(key, default_cfg)
    value = tuple(value)
    if len(value) == 1:
        return value * in_chans
    assert len(value) == in_chans, (
        '%s has %d values but the input has %d channels' % (key, len(value), in_chans))
    return value


def resolve_data_config(args, default_cfg=None, model=None, use_test_size=False, verbose=False):
    """Merge user args, a model's default config and library defaults into a data config.

    Precedence for every field is ``args`` > ``default_cfg`` > built-in default.
    Entries whose value is None (in ``args`` or ``default_cfg``) count as unset.

    Args:
        args: mapping of user overrides. Recognised keys are ``chans``,
            ``input_size``, ``img_size``, ``interpolation``, ``mean``, ``std``
            and ``crop_pct``. None is treated as an empty mapping.
        default_cfg: model default config mapping. When empty or None and
            ``model`` has a ``default_cfg`` attribute, that attribute is used.
        model: optional model whose ``default_cfg`` serves as the fallback.
        use_test_size: when no size is given in ``args``, prefer
            ``default_cfg['test_input_size']`` over ``default_cfg['input_size']``.
        verbose: log the resolved configuration.

    Returns:
        dict with keys ``input_size`` (a 3-tuple whose first entry is the
        channel count), ``interpolation``, ``mean``, ``std`` and ``crop_pct``.

    Raises:
        AssertionError: on malformed ``input_size``/``img_size`` or when an
            explicit ``mean``/``std`` length does not match the channel count.
    """
    args = {} if args is None else args
    if not default_cfg and model is not None:
        default_cfg = getattr(model, 'default_cfg', None)
    default_cfg = default_cfg or {}
    new_config = {}

    # Resolve input/image size: args input_size > args img_size (+ chans)
    # > default_cfg (test_)input_size > (chans, 224, 224).
    in_chans = _first_set('chans', args)
    if in_chans is None:
        in_chans = 3
    input_size = _first_set('input_size', args)
    img_size = _first_set('img_size', args)
    if input_size is not None:
        assert isinstance(input_size, (tuple, list))
        assert len(input_size) == 3
    elif img_size is not None:
        assert isinstance(img_size, int)
        input_size = (in_chans, img_size, img_size)
    else:
        if use_test_size:
            input_size = _first_set('test_input_size', default_cfg)
        if input_size is None:
            input_size = _first_set('input_size', default_cfg)
        if input_size is None:
            input_size = (in_chans, 224, 224)
    input_size = tuple(input_size)
    # The resolved input_size is authoritative for the channel count, so mean/std
    # are broadcast and validated against it even when it came from default_cfg.
    in_chans = input_size[0]
    new_config['input_size'] = input_size

    # Resolve interpolation method; an empty string counts as unset.
    new_config['interpolation'] = (
        _first_set('interpolation', args)
        or _first_set('interpolation', default_cfg)
        or 'bicubic')

    # Resolve normalization statistics; ImageNet defaults only when nothing is set.
    mean = _resolve_norm_stat('mean', args, default_cfg, in_chans)
    new_config['mean'] = IMAGENET_DEFAULT_MEAN if mean is None else mean
    std = _resolve_norm_stat('std', args, default_cfg, in_chans)
    new_config['std'] = IMAGENET_DEFAULT_STD if std is None else std

    # Resolve default crop percentage.
    crop_pct = _first_set('crop_pct', args, default_cfg)
    new_config['crop_pct'] = DEFAULT_CROP_PCT if crop_pct is None else crop_pct

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for name, value in new_config.items():
            _logger.info('\t%s: %s', name, value)
    return new_config