From 0ed0cc7ebad0e14be66b3b8ba3594230944b531e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 21 Nov 2022 16:30:56 -0800
Subject: [PATCH] Add crop_mode for pretrained config / image transforms. Add
 support for dynamo compilation to benchmark/train/validate

---
 benchmark.py                    | 38 ++++++--
 timm/data/config.py             | 20 ++++-
 timm/data/constants.py          |  1 +
 timm/data/loader.py             |  2 +
 timm/data/tf_preprocessing.py   |  3 +-
 timm/data/transforms.py         | 151 +++++++++++++++++++++++++++++++-
 timm/data/transforms_factory.py | 64 ++++++++++----
 timm/models/_pretrained.py      |  2 +-
 train.py                        | 36 ++++++--
 validate.py                     | 48 +++++++---
 10 files changed, 310 insertions(+), 55 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index a03c1982..b2cac8a3 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -56,6 +56,13 @@ try:
 except ImportError as e:
     has_functorch = False
 
+try:
+    import torch._dynamo
+    has_dynamo = True
+except ImportError:
+    has_dynamo = False
+    pass
+
 if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -106,13 +113,19 @@ parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--dynamo-backend', default=None, type=str,
+                    help="Select dynamo backend. Default: None")
+parser.add_argument('--fast-norm', default=False, action='store_true',
+                    help='enable experimental fast-norm')
+
+# codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
-                             help='convert model torchscript for inference')
+                             help='convert model to torchscript for inference')
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
-scripting_group.add_argument('--fast-norm', default=False, action='store_true',
-                             help='enable experimental fast-norm')
+                             help="Enable AOT Autograd optimization.")
+scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                             help="Enable Dynamo optimization.")
 
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -206,6 +219,8 @@ class BenchmarkRunner:
             device='cuda',
             torchscript=False,
             aot_autograd=False,
+            dynamo=False,
+            dynamo_backend=None,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -241,14 +256,21 @@ class BenchmarkRunner:
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
 
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
+        self.input_size = data_config['input_size']
+        self.batch_size = kwargs.pop('batch_size', 256)
 
+        self.scripted = False
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
-
-        self.input_size = data_config['input_size']
-        self.batch_size = kwargs.pop('batch_size', 256)
-
-        if aot_autograd:
+        elif dynamo:
+            assert has_dynamo, "torch._dynamo is needed for --dynamo"
+            torch._dynamo.reset()
+            if dynamo_backend is not None:
+                self.model = torch._dynamo.optimize(dynamo_backend)(self.model)
+            else:
+                self.model = torch._dynamo.optimize()(self.model)
+        elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
             self.model = memory_efficient_fusion(self.model)
 
diff --git a/timm/data/config.py b/timm/data/config.py
index c5da81f1..a65695d0 100644
--- a/timm/data/config.py
+++ b/timm/data/config.py
@@ -5,9 +5,15 @@ from .constants import *
 _logger = logging.getLogger(__name__)
 
 
-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
+def resolve_data_config(
+        args,
+        default_cfg=None,
+        model=None,
+        use_test_size=False,
+        verbose=False
+):
     new_config = {}
-    default_cfg = default_cfg
+    default_cfg = default_cfg or {}
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
         default_cfg = model.default_cfg
 
@@ -63,7 +69,7 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
     elif default_cfg.get('std', None):
         new_config['std'] = default_cfg['std']
 
-    # resolve default crop percentage
+    # resolve default inference crop
     crop_pct = DEFAULT_CROP_PCT
     if args.get('crop_pct', None):
         crop_pct = args['crop_pct']
@@ -74,6 +80,14 @@
             crop_pct = default_cfg['crop_pct']
     new_config['crop_pct'] = crop_pct
 
+    # resolve default crop mode
+    crop_mode = DEFAULT_CROP_MODE
+    if args.get('crop_mode', None):
+        crop_mode = args['crop_mode']
+    elif default_cfg.get('crop_mode', None):
+        crop_mode = default_cfg['crop_mode']
+    new_config['crop_mode'] = crop_mode
+
     if verbose:
         _logger.info('Data processing configuration for current model + dataset:')
         for n, v in new_config.items():
 
diff --git a/timm/data/constants.py b/timm/data/constants.py
index e4d8bb7e..7d468321 100644
--- a/timm/data/constants.py
+++ b/timm/data/constants.py
@@ -1,4 +1,5 @@
 DEFAULT_CROP_PCT = 0.875
+DEFAULT_CROP_MODE = 'center'
 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
 IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
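A quick sketch of the resolution order the config.py hunk above establishes (an explicit user arg beats the pretrained cfg, which beats DEFAULT_CROP_MODE). The dict values are illustrative and assume this patch is applied:

    from timm.data import resolve_data_config

    # nothing specified anywhere -> falls back to DEFAULT_CROP_MODE ('center')
    cfg = resolve_data_config(args={}, default_cfg={})
    assert cfg['crop_mode'] == 'center'

    # an explicit arg wins over the pretrained cfg value
    cfg = resolve_data_config(args={'crop_mode': 'squash'}, default_cfg={'crop_mode': 'border'})
    assert cfg['crop_mode'] == 'squash'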
diff --git a/timm/data/loader.py b/timm/data/loader.py
index 1a4800f8..9d87ed3f 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -211,6 +211,7 @@ def create_loader(
         num_workers=1,
         distributed=False,
         crop_pct=None,
+        crop_mode=None,
         collate_fn=None,
         pin_memory=False,
         fp16=False,  # deprecated, use img_dtype
@@ -240,6 +241,7 @@ def create_loader(
         mean=mean,
         std=std,
         crop_pct=crop_pct,
+        crop_mode=crop_mode,
         tf_preprocessing=tf_preprocessing,
         re_prob=re_prob,
         re_mode=re_mode,
 
diff --git a/timm/data/tf_preprocessing.py b/timm/data/tf_preprocessing.py
index 44b4a3af..b58e0cf3 100644
--- a/timm/data/tf_preprocessing.py
+++ b/timm/data/tf_preprocessing.py
@@ -22,12 +22,13 @@ Hacked together by / Copyright 2020 Ross Wightman
 # limitations under the License.
 # ==============================================================================
 """ImageNet preprocessing for MnasNet."""
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 import numpy as np
 
 IMAGE_SIZE = 224
 CROP_PADDING = 32
 
+tf.compat.v1.disable_eager_execution()
 
 def distorted_bounding_box_crop(image_bytes,
                                 bbox,
diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index 3eb3bc32..40ad4380 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -1,3 +1,9 @@
+import math
+import numbers
+import random
+import warnings
+from typing import List, Sequence
+
 import torch
 import torchvision.transforms.functional as F
 try:
@@ -6,9 +12,6 @@ try:
 except ImportError:
     has_interpolation_mode = False
 from PIL import Image
-import warnings
-import math
-import random
 import numpy as np
 
 
@@ -96,6 +99,19 @@ def interp_mode_to_str(mode):
 _RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
 
 
+def _setup_size(size, error_msg):
+    if isinstance(size, numbers.Number):
+        return int(size), int(size)
+
+    if isinstance(size, Sequence) and len(size) == 1:
+        return size[0], size[0]
+
+    if len(size) != 2:
+        raise ValueError(error_msg)
+
+    return size
+
+
 class RandomResizedCropAndInterpolation:
     """Crop the given PIL Image to random size and aspect ratio with random
     interpolation.
@@ -195,3 +211,132 @@ class RandomResizedCropAndInterpolation:
         format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
         format_string += ', interpolation={0})'.format(interpolate_str)
         return format_string
+
+
+def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
+    """Center crops and/or pads the given image.
+    If the image is torch Tensor, it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
+
+    Args:
+        img (PIL Image or Tensor): Image to be cropped.
+        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
+            it is used for both directions.
+        fill (int, Tuple[int]): Padding color
+
+    Returns:
+        PIL Image or Tensor: Cropped image.
+    """
+    if isinstance(output_size, numbers.Number):
+        output_size = (int(output_size), int(output_size))
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+        output_size = (output_size[0], output_size[0])
+
+    _, image_height, image_width = F.get_dimensions(img)
+    crop_height, crop_width = output_size
+
+    if crop_width > image_width or crop_height > image_height:
+        padding_ltrb = [
+            (crop_width - image_width) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height) // 2 if crop_height > image_height else 0,
+            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+        ]
+        img = F.pad(img, padding_ltrb, fill=fill)
+        _, image_height, image_width = F.get_dimensions(img)
+        if crop_width == image_width and crop_height == image_height:
+            return img
+
+    crop_top = int(round((image_height - crop_height) / 2.0))
+    crop_left = int(round((image_width - crop_width) / 2.0))
+    return F.crop(img, crop_top, crop_left, crop_height, crop_width)
+
+
+class CenterCropOrPad(torch.nn.Module):
+    """Crops the given image at the center.
+    If the image is torch Tensor, it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
+    """
+
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+        self.fill = fill
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            PIL Image or Tensor: Cropped image.
+        """
+        return center_crop_or_pad(img, self.size, fill=self.fill)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size})"
+
+
+class ResizeKeepRatio:
+    """ Resize and Keep Ratio
+    """
+
+    def __init__(
+            self,
+            size,
+            longest=0.,
+            interpolation='bilinear',
+            fill=0,
+    ):
+        if isinstance(size, (list, tuple)):
+            self.size = tuple(size)
+        else:
+            self.size = (size, size)
+        self.interpolation = str_to_interp_mode(interpolation)
+        self.longest = float(longest)
+        self.fill = fill
+
+    @staticmethod
+    def get_params(img, target_size, longest):
+        """Get parameters
+
+        Args:
+            img (PIL Image): Image to be resized.
+            target_size (Tuple[int, int]): Size of output
+
+        Returns:
+            tuple: size (h, w) to be passed to ``resize``
+        """
+        source_size = img.size[::-1]  # h, w
+        h, w = source_size
+        target_h, target_w = target_size
+        ratio_h = h / target_h
+        ratio_w = w / target_w
+        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
+        size = [round(x / ratio) for x in source_size]
+        return size
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be resized.
+
+        Returns:
+            PIL Image: Resized image; padding/cropping to the exact target size is left to subsequent transforms
+        """
+        size = self.get_params(img, self.size, self.longest)
+        img = F.resize(img, size, self.interpolation)
+        return img
+
+    def __repr__(self):
+        interpolate_str = interp_mode_to_str(self.interpolation)
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += f', interpolation={interpolate_str}'
+        format_string += f', longest={self.longest:.3f})'
+        return format_string
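The two new transforms compose the way 'border' mode does in transforms_factory.py below. A small sketch of that pairing on a made-up non-square image (assumes this patch plus a torchvision recent enough to provide F.get_dimensions):

    from PIL import Image
    from timm.data.transforms import CenterCropOrPad, ResizeKeepRatio

    img = Image.new('RGB', (320, 240))  # (width, height), deliberately non-square

    # longest=1.0 fits the longest edge to the target, preserving aspect ratio
    scaled = ResizeKeepRatio((224, 224), longest=1.)(img)
    print(scaled.size)  # (224, 168) - nothing cropped away yet

    # pad the short edge out to the target (or center crop if it were oversized)
    out = CenterCropOrPad(224)(scaled)
    print(out.size)  # (224, 224)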
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index a5facbf5..6c28383a 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -10,7 +10,8 @@ from torchvision import transforms
 
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
-from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
+from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
+    ResizeKeepRatio, CenterCropOrPad, ToNumpy
 from timm.data.random_erasing import RandomErasing
 
 
@@ -130,26 +131,49 @@ def transforms_imagenet_train(
 
 def transforms_imagenet_eval(
         img_size=224,
         crop_pct=None,
+        crop_mode=None,
         interpolation='bilinear',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD):
+        std=IMAGENET_DEFAULT_STD
+):
     crop_pct = crop_pct or DEFAULT_CROP_PCT
 
     if isinstance(img_size, (tuple, list)):
         assert len(img_size) == 2
-        if img_size[-1] == img_size[-2]:
-            # fall-back to older behaviour so Resize scales to shortest edge if target is square
-            scale_size = int(math.floor(img_size[0] / crop_pct))
-        else:
-            scale_size = tuple([int(x / crop_pct) for x in img_size])
+        scale_size = tuple([math.floor(x / crop_pct) for x in img_size])
     else:
-        scale_size = int(math.floor(img_size / crop_pct))
+        scale_size = math.floor(img_size / crop_pct)
+        scale_size = (scale_size, scale_size)
+
+    if crop_mode == 'squash':
+        # squash mode scales each edge to 1/pct of target, then crops
+        # aspect ratio is not preserved, no image lost if crop_pct == 1.0
+        tfl = [
+            transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
+            transforms.CenterCrop(img_size),
+        ]
+    elif crop_mode == 'border':
+        # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
+        # no image lost if crop_pct == 1.0
+        fill = [round(255 * v) for v in mean]
+        tfl = [
+            ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
+            CenterCropOrPad(img_size, fill=fill),
+        ]
+    else:
+        # default crop mode is center
+        # aspect ratio is preserved, crops center within image, no borders are added, some image is lost if crop_pct < 1.0
+        if scale_size[0] == scale_size[1]:
+            # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
+            tfl = [
+                transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
+            ]
+        else:
+            # resize shortest edge to matching target dim for non-square target
+            tfl = [ResizeKeepRatio(scale_size)]
+        tfl += [transforms.CenterCrop(img_size)]
 
-    tfl = [
-        transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
-        transforms.CenterCrop(img_size),
-    ]
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
         tfl += [ToNumpy()]
@@ -157,8 +181,9 @@
             transforms.ToTensor(),
             transforms.Normalize(
-                     mean=torch.tensor(mean),
-                     std=torch.tensor(std))
+                mean=torch.tensor(mean),
+                std=torch.tensor(std),
+            )
         ]
 
     return transforms.Compose(tfl)
@@ -183,6 +208,7 @@ def create_transform(
         re_count=1,
         re_num_splits=0,
         crop_pct=None,
+        crop_mode=None,
         tf_preprocessing=False,
         separate=False):
 
@@ -204,7 +230,8 @@ def create_transform(
             interpolation=interpolation,
             use_prefetcher=use_prefetcher,
             mean=mean,
-            std=std)
+            std=std,
+        )
     elif is_training:
         transform = transforms_imagenet_train(
             img_size,
@@ -222,7 +249,8 @@ def create_transform(
             re_mode=re_mode,
             re_count=re_count,
             re_num_splits=re_num_splits,
-            separate=separate)
+            separate=separate,
+        )
     else:
         assert not separate, "Separate transforms not supported for validation preprocessing"
         transform = transforms_imagenet_eval(
@@ -231,6 +259,8 @@ def create_transform(
             use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
-            crop_pct=crop_pct)
+            crop_pct=crop_pct,
+            crop_mode=crop_mode,
+        )
 
     return transform
 
diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py
index 8a63ebe9..61fb718c 100644
--- a/timm/models/_pretrained.py
+++ b/timm/models/_pretrained.py
@@ -25,7 +25,7 @@ class PretrainedCfg:
     interpolation: str = 'bicubic'
     crop_pct: float = 0.875
     test_crop_pct: Optional[float] = None
-    crop_type: str = 'pct'
+    crop_mode: str = 'center'
     mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
     std: Tuple[float, ...] = (0.229, 0.224, 0.225)
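Pulling the factory pieces together, a sanity-check sketch of all three crop modes (dummy gray image; the shapes are what the code above should produce when crop_pct covers the full image):

    from PIL import Image
    from timm.data.transforms_factory import transforms_imagenet_eval

    img = Image.new('RGB', (320, 240), color=(128, 128, 128))
    for mode in ('center', 'squash', 'border'):
        tf = transforms_imagenet_eval(img_size=224, crop_pct=1.0, crop_mode=mode)
        print(mode, tuple(tf(img).shape))  # (3, 224, 224) for every mode

'squash' reaches 224x224 by distorting the aspect ratio, 'border' by padding with the mean color, and the default 'center' by cropping away the excess of the long edge.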
diff --git a/train.py b/train.py
index e004881d..b6027f1d 100755
--- a/train.py
+++ b/train.py
@@ -66,6 +66,13 @@ try:
 except ImportError as e:
     has_functorch = False
 
+try:
+    import torch._dynamo
+    has_dynamo = True
+except ImportError:
+    has_dynamo = False
+    pass
+
 _logger = logging.getLogger('train')
 
 
@@ -130,17 +137,22 @@ group.add_argument('-vb', '--validation-batch-size', type=int, default=None, met
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
-scripting_group = group.add_mutually_exclusive_group()
-scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
-                             help='torch.jit.script the full model')
-scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
-group.add_argument('--fast-norm', default=False, action='store_true',
-                   help='enable experimental fast-norm')
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
+group.add_argument('--fast-norm', default=False, action='store_true',
+                   help='enable experimental fast-norm')
+parser.add_argument('--dynamo-backend', default=None, type=str,
+                    help="Select dynamo backend. Default: None")
+
+scripting_group = group.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
+                             help='torch.jit.script the full model')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support.")
+scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                             help="Enable Dynamo optimization.")
 
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
@@ -473,10 +485,16 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
-
-    if args.aot_autograd:
+    elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
+    elif args.dynamo:
+        # FIXME dynamo might need to move below DDP wrapping? TBD
+        assert has_dynamo, "torch._dynamo is needed for --dynamo"
+        if args.dynamo_backend is not None:
+            model = torch._dynamo.optimize(args.dynamo_backend)(model)
+        else:
+            model = torch._dynamo.optimize()(model)
 
     if args.lr is None:
         global_batch_size = args.batch_size * args.world_size
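The guarded-import plus optimize pattern used in all three scripts, condensed into a standalone sketch (the tiny model is illustrative; backend strings depend on what the installed torch build registers):

    import torch
    import torch.nn as nn

    try:
        import torch._dynamo
        has_dynamo = True
    except ImportError:
        has_dynamo = False

    model = nn.Linear(8, 8)
    if has_dynamo:
        torch._dynamo.reset()                    # clear any previously compiled state
        model = torch._dynamo.optimize()(model)  # or torch._dynamo.optimize('<backend>')(model)
    out = model(torch.randn(2, 8))

On the command line this corresponds to passing --dynamo, optionally with --dynamo-backend, to benchmark.py, train.py, or validate.py.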
diff --git a/validate.py b/validate.py
index 1a1ea9cd..e0f42e03 100755
--- a/validate.py
+++ b/validate.py
@@ -46,6 +46,13 @@ try:
 except ImportError as e:
     has_functorch = False
 
+try:
+    import torch._dynamo
+    has_dynamo = True
+except ImportError:
+    has_dynamo = False
+    pass
+
 _logger = logging.getLogger('validate')
 
 
@@ -72,6 +79,8 @@ parser.add_argument('--use-train-size', action='store_true', default=False,
                     help='force use of train input size, even when test size is specified in pretrained cfg')
 parser.add_argument('--crop-pct', default=None, type=float, metavar='N',
                     help='Input image center crop pct')
+parser.add_argument('--crop-mode', default=None, type=str,
+                    metavar='MODE', help='Input image crop mode (squash, border, center). Model default if None.')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -112,15 +121,21 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
-scripting_group = parser.add_mutually_exclusive_group()
-scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
-                             help='torch.jit.script the full model')
-scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--dynamo-backend', default=None, type=str,
+                    help="Select dynamo backend. Default: None")
+
+scripting_group = parser.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', default=False, action='store_true',
+                             help='torch.jit.script the full model')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support.")
+scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                             help="Enable Dynamo optimization.")
+
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@@ -196,21 +211,27 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)
 
+    model = model.to(device)
+    if args.channels_last:
+        model = model.to(memory_format=torch.channels_last)
+
     if args.torchscript:
-        torch.jit.optimized_execution(True)
+        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         model = torch.jit.script(model)
-
-    if args.aot_autograd:
+    elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
+    elif args.dynamo:
+        assert has_dynamo, "torch._dynamo is needed for --dynamo"
+        torch._dynamo.reset()
+        if args.dynamo_backend is not None:
+            model = torch._dynamo.optimize(args.dynamo_backend)(model)
+        else:
+            model = torch._dynamo.optimize()(model)
 
-    model = model.to(device)
     if use_amp == 'apex':
         model = amp.initialize(model, opt_level='O1')
 
-    if args.channels_last:
-        model = model.to(memory_format=torch.channels_last)
-
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
 
@@ -248,6 +269,7 @@ def validate(args):
         std=data_config['std'],
         num_workers=args.workers,
         crop_pct=crop_pct,
+        crop_mode=data_config['crop_mode'],
         pin_memory=args.pin_mem,
         device=device,
         tf_preprocessing=args.tf_preprocessing,
@@ -376,7 +398,7 @@ def main():
         model_cfgs = [(n, '') for n in model_names]
     elif not is_model(args.model):
         # model name doesn't exist, try as wildcard filter
-        model_names = list_models(args.model)
+        model_names = list_models(args.model, pretrained=True)
         model_cfgs = [(n, '') for n in model_names]
 
     if not model_cfgs and os.path.isfile(args.model):
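End to end, the new options then surface on the validate.py command line; an illustrative invocation (dataset path and model name are placeholders) would look like `python validate.py /imagenet --model resnet50 --crop-mode squash --crop-pct 1.0 --dynamo`.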