import logging
from dataclasses import dataclass
from typing import Tuple, Optional, Union

from .constants import *


_logger = logging.getLogger(__name__)


@dataclass
class AugCfg:
    """Plain container of augmentation settings and their default values."""
    # NOTE(review): field meanings are inferred from their names only; nothing
    # in this module reads them -- confirm against the transform factory.
    scale_range: Tuple[float, float] = (0.08, 1.0)
    ratio_range: Tuple[float, float] = (3 / 4, 4 / 3)
    hflip_prob: float = 0.5
    vflip_prob: float = 0.
    color_jitter: float = 0.4
    auto_augment: Optional[str] = None
    re_prob: float = 0.
    re_mode: str = 'const'
    re_count: int = 1
    num_aug_splits: int = 0


@dataclass
class PreprocessCfg:
    """Plain container of pre-processing settings and their default values."""
    input_size: Tuple[int, int, int] = (3, 224, 224)
    mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN
    std: Tuple[float, ...] = IMAGENET_DEFAULT_STD
    interpolation: str = 'bilinear'
    crop_pct: float = 0.875
    # Optional nested augmentation settings; None means "not specified".
    aug: Optional[AugCfg] = None


@dataclass
class MixupCfg:
    """Plain container of mixup / cutmix settings and their default values."""
    prob: float = 1.0
    switch_prob: float = 0.5
    mixup_alpha: float = 1.
    cutmix_alpha: float = 0.
    cutmix_minmax: Optional[Tuple[float, float]] = None
    mode: str = 'batch'
    correct_lam: bool = True
    label_smoothing: float = 0.1
    num_classes: int = 0


def _lookup_arg(args, name, default=None):
    """Return option *name* from *args*, trying item access, then attribute access."""
    try:
        return args[name]
    except (KeyError, TypeError):
        # KeyError: mapping without the key. TypeError: non-subscriptable
        # object (e.g. an argparse Namespace) -- fall back to attributes.
        return getattr(args, name, default)


def _resolve_channel_stat(key, args, default_cfg, in_chans, fallback):
    """Resolve a per-channel statistic ('mean' or 'std'): args > default_cfg > fallback."""
    if key in args and args[key] is not None:
        stat = tuple(args[key])
        if len(stat) == 1:
            # A single value is broadcast to every input channel.
            stat = stat * in_chans
        else:
            assert len(stat) == in_chans, \
                '%s must have 1 or %d values, got %d' % (key, in_chans, len(stat))
        return stat
    if key in default_cfg:
        return default_cfg[key]
    return fallback


def resolve_data_config(args, default_cfg=None, model=None, use_test_size=False, verbose=False):
    """Build the data pre-processing configuration for a model + dataset.

    Each setting is resolved with the precedence: explicit (non-None) value in
    ``args`` > value in ``default_cfg`` > hard-coded fallback.

    Args:
        args: Mapping of option name to value. Recognised keys: 'chans',
            'input_size', 'img_size', 'interpolation', 'mean', 'std',
            'crop_pct', plus the mixup options 'mixup', 'cutmix',
            'cutmix_minmax', 'mixup_prob', 'mixup_switch_prob', 'mixup_mode',
            'smoothing' and 'num_classes'.
        default_cfg: Mapping of default values. When empty or None and
            ``model`` has a ``default_cfg`` attribute, that attribute is used.
        model: Optional object providing a ``default_cfg`` attribute.
        use_test_size: Prefer ``default_cfg['test_input_size']`` over
            ``default_cfg['input_size']`` when no size is given in ``args``.
        verbose: Log the resolved configuration at INFO level.

    Returns:
        dict with keys 'input_size', 'interpolation', 'mean', 'std' and
        'crop_pct', plus 'mixup' (a dict of mixup settings) when mixup or
        cutmix is enabled in ``args``.

    Raises:
        AssertionError: if 'input_size' is not a 3-element tuple/list,
            'img_size' is not an int, or 'mean'/'std' has a length other than
            1 or the number of input channels.
    """
    new_config = {}
    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
        default_cfg = model.default_cfg
    # Normalise to a mapping so the membership tests below are always valid
    # (also avoids sharing a mutable default argument between calls).
    default_cfg = default_cfg or {}

    # Resolve input/image size
    in_chans = 3
    if 'chans' in args and args['chans'] is not None:
        in_chans = args['chans']

    input_size = (in_chans, 224, 224)
    if 'input_size' in args and args['input_size'] is not None:
        assert isinstance(args['input_size'], (tuple, list)), \
            'input_size must be a tuple or list'
        assert len(args['input_size']) == 3, 'input_size must have 3 elements'
        input_size = tuple(args['input_size'])
        in_chans = input_size[0]  # input_size overrides in_chans
    elif 'img_size' in args and args['img_size'] is not None:
        assert isinstance(args['img_size'], int), 'img_size must be an int'
        input_size = (in_chans, args['img_size'], args['img_size'])
    elif use_test_size and 'test_input_size' in default_cfg:
        input_size = default_cfg['test_input_size']
    elif 'input_size' in default_cfg:
        input_size = default_cfg['input_size']
    new_config['input_size'] = input_size

    # Resolve interpolation method (an empty/falsy value in args is ignored)
    new_config['interpolation'] = 'bicubic'
    if 'interpolation' in args and args['interpolation']:
        new_config['interpolation'] = args['interpolation']
    elif 'interpolation' in default_cfg:
        new_config['interpolation'] = default_cfg['interpolation']

    # Resolve dataset + model mean / std deviation for normalization
    new_config['mean'] = _resolve_channel_stat(
        'mean', args, default_cfg, in_chans, IMAGENET_DEFAULT_MEAN)
    new_config['std'] = _resolve_channel_stat(
        'std', args, default_cfg, in_chans, IMAGENET_DEFAULT_STD)

    # Resolve default crop percentage
    new_config['crop_pct'] = DEFAULT_CROP_PCT
    if 'crop_pct' in args and args['crop_pct'] is not None:
        new_config['crop_pct'] = args['crop_pct']
    elif 'crop_pct' in default_cfg:
        new_config['crop_pct'] = default_cfg['crop_pct']

    # Resolve mixup / cutmix. Options are read the same way as every other
    # option (item access first) so that mapping-style args are honoured;
    # a missing or None alpha counts as disabled.
    mixup_alpha = _lookup_arg(args, 'mixup', 0.) or 0.
    cutmix_alpha = _lookup_arg(args, 'cutmix', 0.) or 0.
    cutmix_minmax = _lookup_arg(args, 'cutmix_minmax', None)
    if mixup_alpha > 0 or cutmix_alpha > 0 or cutmix_minmax is not None:
        mixup_defaults = MixupCfg()  # fallbacks for options absent from args
        new_config['mixup'] = dict(
            mixup_alpha=mixup_alpha,
            cutmix_alpha=cutmix_alpha,
            cutmix_minmax=cutmix_minmax,
            prob=_lookup_arg(args, 'mixup_prob', mixup_defaults.prob),
            switch_prob=_lookup_arg(args, 'mixup_switch_prob', mixup_defaults.switch_prob),
            mode=_lookup_arg(args, 'mixup_mode', mixup_defaults.mode),
            label_smoothing=_lookup_arg(args, 'smoothing', mixup_defaults.label_smoothing),
            num_classes=_lookup_arg(args, 'num_classes', mixup_defaults.num_classes))

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for name, value in new_config.items():
            _logger.info('\t%s: %s', name, str(value))

    return new_config