Add crop_mode for pretrained config / image transforms. Add support for dynamo compilation to benchmark/train/validate

pull/1582/head
Ross Wightman 2 years ago committed by Ross Wightman
parent 8fca002c06
commit 9da7e3a799
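In brief, the commit threads a new crop_mode ('squash', 'border', 'center') through the data config / transform pipeline and adds a --dynamo / --dynamo-backend path alongside --torchscript and --aot-autograd in benchmark.py, train.py and validate.py. A minimal sketch of what the dynamo path amounts to (model name, batch shape and backend choice below are illustrative, not part of the commit; requires a PyTorch build that ships torch._dynamo):

import torch
import timm

# optional dependency guard, mirroring the scripts below
try:
    import torch._dynamo
    has_dynamo = True
except ImportError:
    has_dynamo = False

model = timm.create_model('resnet50', pretrained=False).eval()
if has_dynamo:
    torch._dynamo.reset()
    model = torch._dynamo.optimize()(model)  # or optimize('inductor'), etc., as with --dynamo-backend

with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1000])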

@@ -56,6 +56,13 @@ try:
 except ImportError as e:
     has_functorch = False
 
+try:
+    import torch._dynamo
+    has_dynamo = True
+except ImportError:
+    has_dynamo = False
+    pass
+
 if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -106,13 +113,19 @@ parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--dynamo-backend', default=None, type=str,
+                    help="Select dynamo backend. Default: None")
+parser.add_argument('--fast-norm', default=False, action='store_true',
+                    help='enable experimental fast-norm')
+
+# codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='convert model torchscript for inference')
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
-scripting_group.add_argument('--fast-norm', default=False, action='store_true',
-                             help='enable experimental fast-norm')
+                             help="Enable AOT Autograd optimization.")
+scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                             help="Enable Dynamo optimization.")
 
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -206,6 +219,8 @@ class BenchmarkRunner:
             device='cuda',
             torchscript=False,
             aot_autograd=False,
+            dynamo=False,
+            dynamo_backend=None,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -241,14 +256,21 @@ class BenchmarkRunner:
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
 
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
+        self.input_size = data_config['input_size']
+        self.batch_size = kwargs.pop('batch_size', 256)
+
         self.scripted = False
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
-        self.input_size = data_config['input_size']
-        self.batch_size = kwargs.pop('batch_size', 256)
-
-        if aot_autograd:
+        elif dynamo:
+            assert has_dynamo, "torch._dynamo is needed for --dynamo"
+            torch._dynamo.reset()
+            if dynamo_backend is not None:
+                self.model = torch._dynamo.optimize(dynamo_backend)(self.model)
+            else:
+                self.model = torch._dynamo.optimize()(self.model)
+        elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
             self.model = memory_efficient_fusion(self.model)

@@ -5,9 +5,15 @@ from .constants import *
 
 _logger = logging.getLogger(__name__)
 
 
-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
+def resolve_data_config(
+        args,
+        default_cfg=None,
+        model=None,
+        use_test_size=False,
+        verbose=False
+):
     new_config = {}
-    default_cfg = default_cfg
+    default_cfg = default_cfg or {}
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
         default_cfg = model.default_cfg
@@ -63,7 +69,7 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     elif default_cfg.get('std', None):
         new_config['std'] = default_cfg['std']
 
-    # resolve default crop percentage
+    # resolve default inference crop
     crop_pct = DEFAULT_CROP_PCT
     if args.get('crop_pct', None):
         crop_pct = args['crop_pct']
@@ -74,6 +80,14 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
         crop_pct = default_cfg['crop_pct']
     new_config['crop_pct'] = crop_pct
 
+    # resolve default crop mode
+    crop_mode = DEFAULT_CROP_MODE
+    if args.get('crop_mode', None):
+        crop_mode = args['crop_mode']
+    elif default_cfg.get('crop_mode', None):
+        crop_mode = default_cfg['crop_mode']
+    new_config['crop_mode'] = crop_mode
+
     if verbose:
         _logger.info('Data processing configuration for current model + dataset:')
         for n, v in new_config.items():
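As a quick aside, the resolution order added above is: explicit args over the model's default_cfg over DEFAULT_CROP_MODE. A small sketch of how that surfaces through resolve_data_config (model name and printed values are illustrative):

import timm
from timm.data import resolve_data_config

model = timm.create_model('resnet50', pretrained=False)

# a command-line style override wins over the model's pretrained cfg
cfg = resolve_data_config({'crop_mode': 'squash'}, model=model)
print(cfg['input_size'], cfg['crop_pct'], cfg['crop_mode'])  # e.g. (3, 224, 224) 0.875 squash

# with no override, crop_mode falls back to the model cfg or DEFAULT_CROP_MODE
cfg = resolve_data_config({}, model=model)
print(cfg['crop_mode'])  # 'center'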

@@ -1,4 +1,5 @@
 DEFAULT_CROP_PCT = 0.875
+DEFAULT_CROP_MODE = 'center'
 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
 IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)

@@ -211,6 +211,7 @@ def create_loader(
         num_workers=1,
         distributed=False,
         crop_pct=None,
+        crop_mode=None,
         collate_fn=None,
         pin_memory=False,
         fp16=False,  # deprecated, use img_dtype
@@ -240,6 +241,7 @@ def create_loader(
         mean=mean,
         std=std,
         crop_pct=crop_pct,
+        crop_mode=crop_mode,
         tf_preprocessing=tf_preprocessing,
         re_prob=re_prob,
         re_mode=re_mode,
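The new kwarg simply forwards from create_loader through create_transform to transforms_imagenet_eval. A hedged sketch of using it end to end (dataset root is a placeholder, values illustrative):

import timm
from timm.data import create_dataset, create_loader, resolve_data_config

model = timm.create_model('resnet50', pretrained=False)
cfg = resolve_data_config({'crop_mode': 'border'}, model=model)

dataset = create_dataset('', root='/path/to/imagenet/val')  # placeholder root
loader = create_loader(
    dataset,
    input_size=cfg['input_size'],
    batch_size=64,
    is_training=False,
    use_prefetcher=False,
    interpolation=cfg['interpolation'],
    mean=cfg['mean'],
    std=cfg['std'],
    num_workers=4,
    crop_pct=cfg['crop_pct'],
    crop_mode=cfg['crop_mode'],  # new kwarg plumbed through to the eval transforms
)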

@@ -22,12 +22,13 @@ Hacked together by / Copyright 2020 Ross Wightman
 # limitations under the License.
 # ==============================================================================
 """ImageNet preprocessing for MnasNet."""
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 import numpy as np
 
 IMAGE_SIZE = 224
 CROP_PADDING = 32
 
+tf.compat.v1.disable_eager_execution()
 
 def distorted_bounding_box_crop(image_bytes,
                                 bbox,

@@ -1,3 +1,9 @@
+import math
+import numbers
+import random
+import warnings
+from typing import List, Sequence
+
 import torch
 import torchvision.transforms.functional as F
 try:
@@ -6,9 +12,6 @@ try:
 except ImportError:
     has_interpolation_mode = False
 from PIL import Image
-import warnings
-import math
-import random
 import numpy as np
@@ -96,6 +99,19 @@ def interp_mode_to_str(mode):
 
 _RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
 
 
+def _setup_size(size, error_msg):
+    if isinstance(size, numbers.Number):
+        return int(size), int(size)
+
+    if isinstance(size, Sequence) and len(size) == 1:
+        return size[0], size[0]
+
+    if len(size) != 2:
+        raise ValueError(error_msg)
+
+    return size
+
+
 class RandomResizedCropAndInterpolation:
     """Crop the given PIL Image to random size and aspect ratio with random interpolation.
@@ -195,3 +211,132 @@ class RandomResizedCropAndInterpolation:
         format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
         format_string += ', interpolation={0})'.format(interpolate_str)
         return format_string
+
+
+def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
+    """Center crops and/or pads the given image.
+    If the image is torch Tensor, it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
+
+    Args:
+        img (PIL Image or Tensor): Image to be cropped.
+        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
+            it is used for both directions.
+        fill (int, Tuple[int]): Padding color
+
+    Returns:
+        PIL Image or Tensor: Cropped image.
+    """
+    if isinstance(output_size, numbers.Number):
+        output_size = (int(output_size), int(output_size))
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+        output_size = (output_size[0], output_size[0])
+
+    _, image_height, image_width = F.get_dimensions(img)
+    crop_height, crop_width = output_size
+
+    if crop_width > image_width or crop_height > image_height:
+        padding_ltrb = [
+            (crop_width - image_width) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height) // 2 if crop_height > image_height else 0,
+            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+        ]
+        img = F.pad(img, padding_ltrb, fill=fill)
+        _, image_height, image_width = F.get_dimensions(img)
+        if crop_width == image_width and crop_height == image_height:
+            return img
+
+    crop_top = int(round((image_height - crop_height) / 2.0))
+    crop_left = int(round((image_width - crop_width) / 2.0))
+    return F.crop(img, crop_top, crop_left, crop_height, crop_width)
+
+
+class CenterCropOrPad(torch.nn.Module):
+    """Crops the given image at the center.
+    If the image is torch Tensor, it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
+    """
+
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+        self.fill = fill
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            PIL Image or Tensor: Cropped image.
+        """
+        return center_crop_or_pad(img, self.size, fill=self.fill)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size})"
+
+
+class ResizeKeepRatio:
+    """ Resize and Keep Ratio
+    """
+
+    def __init__(
+            self,
+            size,
+            longest=0.,
+            interpolation='bilinear',
+            fill=0,
+    ):
+        if isinstance(size, (list, tuple)):
+            self.size = tuple(size)
+        else:
+            self.size = (size, size)
+        self.interpolation = str_to_interp_mode(interpolation)
+        self.longest = float(longest)
+        self.fill = fill
+
+    @staticmethod
+    def get_params(img, target_size, longest):
+        """Get parameters
+
+        Args:
+            img (PIL Image): Image to be cropped.
+            target_size (Tuple[int, int]): Size of output
+        Returns:
+            tuple: params (h, w) and (l, r, t, b) to be passed to ``resize`` and ``pad`` respectively
+        """
+        source_size = img.size[::-1]  # h, w
+        h, w = source_size
+        target_h, target_w = target_size
+        ratio_h = h / target_h
+        ratio_w = w / target_w
+        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
+        size = [round(x / ratio) for x in source_size]
+        return size
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped and resized.
+
+        Returns:
+            PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
+        """
+        size = self.get_params(img, self.size, self.longest)
+        img = F.resize(img, size, self.interpolation)
+        return img
+
+    def __repr__(self):
+        interpolate_str = interp_mode_to_str(self.interpolation)
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += f', interpolation={interpolate_str})'
+        format_string += f', longest={self.longest:.3f})'
+        return format_string
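To make the geometry of the two new transforms concrete, a quick sketch (sizes illustrative; assumes a torchvision recent enough to provide F.get_dimensions):

from PIL import Image
from timm.data.transforms import ResizeKeepRatio, CenterCropOrPad

img = Image.new('RGB', (640, 480))  # (w, h)

# longest=1.0 scales so the longest edge fits the target:
# ratio = max(480/224, 640/224) ≈ 2.857 -> resized to (224, 168) in (w, h)
resized = ResizeKeepRatio((224, 224), longest=1.)(img)
print(resized.size)  # (224, 168)

# CenterCropOrPad then pads the short edge out to the target size
out = CenterCropOrPad((224, 224), fill=(128, 128, 128))(resized)
print(out.size)  # (224, 224)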

@@ -10,7 +10,8 @@ from torchvision import transforms
 
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
-from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
+from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
+    ResizeKeepRatio, CenterCropOrPad, ToNumpy
 from timm.data.random_erasing import RandomErasing
@@ -130,26 +131,49 @@ def transforms_imagenet_train(
 def transforms_imagenet_eval(
         img_size=224,
         crop_pct=None,
+        crop_mode=None,
         interpolation='bilinear',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD):
+        std=IMAGENET_DEFAULT_STD
+):
     crop_pct = crop_pct or DEFAULT_CROP_PCT
 
     if isinstance(img_size, (tuple, list)):
         assert len(img_size) == 2
-        if img_size[-1] == img_size[-2]:
-            # fall-back to older behaviour so Resize scales to shortest edge if target is square
-            scale_size = int(math.floor(img_size[0] / crop_pct))
-        else:
-            scale_size = tuple([int(x / crop_pct) for x in img_size])
+        scale_size = tuple([math.floor(x / crop_pct) for x in img_size])
     else:
-        scale_size = int(math.floor(img_size / crop_pct))
+        scale_size = math.floor(img_size / crop_pct)
+        scale_size = (scale_size, scale_size)
 
-    tfl = [
-        transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
-        transforms.CenterCrop(img_size),
-    ]
+    if crop_mode == 'squash':
+        # squash mode scales each edge to 1/pct of target, then crops
+        # aspect ratio is not preserved, no img lost if crop_pct == 1.0
+        tfl = [
+            transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
+            transforms.CenterCrop(img_size),
+        ]
+    elif crop_mode == 'border':
+        # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
+        # no image lost if crop_pct == 1.0
+        fill = [round(255 * v) for v in mean]
+        tfl = [
+            ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
+            CenterCropOrPad(img_size, fill=fill),
+        ]
+    else:
+        # default crop mode is center
+        # aspect ratio is preserved, crops center within image, no borders are added, image is lost
+        if scale_size[0] == scale_size[1]:
+            # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
+            tfl = [
+                transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
+            ]
+        else:
+            # resize shortest edge to matching target dim for non-square target
+            tfl = [ResizeKeepRatio(scale_size)]
+        tfl += [transforms.CenterCrop(img_size)]
+
     if use_prefetcher:
         # prefetcher and collate will handle tensor conversion and norm
         tfl += [ToNumpy()]
@@ -158,7 +182,8 @@ def transforms_imagenet_eval(
             transforms.ToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
-                std=torch.tensor(std))
+                std=torch.tensor(std),
+            )
         ]
 
     return transforms.Compose(tfl)
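For reference, with img_size=224 and crop_pct=0.875 the scale size is floor(224 / 0.875) = 256: 'center' resizes the short edge to 256 and center-crops 224, 'squash' resizes both edges to 256, and 'border' fits the long edge to 256 then pads with the mean color before cropping. A small sketch comparing the three modes (input size illustrative):

from PIL import Image
from timm.data.transforms_factory import transforms_imagenet_eval

img = Image.new('RGB', (640, 480))
for mode in ('center', 'squash', 'border'):
    tf = transforms_imagenet_eval(img_size=224, crop_pct=0.875, crop_mode=mode)
    x = tf(img)
    print(mode, tuple(x.shape))  # all (3, 224, 224); the pre-crop geometry differs per mode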
@@ -183,6 +208,7 @@ def create_transform(
         re_count=1,
         re_num_splits=0,
         crop_pct=None,
+        crop_mode=None,
         tf_preprocessing=False,
         separate=False):
@@ -204,7 +230,8 @@ def create_transform(
             interpolation=interpolation,
             use_prefetcher=use_prefetcher,
             mean=mean,
-            std=std)
+            std=std,
+        )
     elif is_training:
         transform = transforms_imagenet_train(
             img_size,
@@ -222,7 +249,8 @@ def create_transform(
             re_mode=re_mode,
             re_count=re_count,
             re_num_splits=re_num_splits,
-            separate=separate)
+            separate=separate,
+        )
     else:
         assert not separate, "Separate transforms not supported for validation preprocessing"
         transform = transforms_imagenet_eval(
@@ -231,6 +259,8 @@ def create_transform(
             use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
-            crop_pct=crop_pct)
+            crop_pct=crop_pct,
+            crop_mode=crop_mode,
+        )
 
     return transform

@@ -25,7 +25,7 @@ class PretrainedCfg:
     interpolation: str = 'bicubic'
     crop_pct: float = 0.875
     test_crop_pct: Optional[float] = None
-    crop_type: str = 'pct'
+    crop_mode: str = 'center'
     mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
     std: Tuple[float, ...] = (0.229, 0.224, 0.225)

@@ -66,6 +66,13 @@ try:
 except ImportError as e:
     has_functorch = False
 
+try:
+    import torch._dynamo
+    has_dynamo = True
+except ImportError:
+    has_dynamo = False
+    pass
+
 _logger = logging.getLogger('train')
@@ -130,17 +137,22 @@ group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
-scripting_group = group.add_mutually_exclusive_group()
-scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
-                             help='torch.jit.script the full model')
-scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
-group.add_argument('--fast-norm', default=False, action='store_true',
-                   help='enable experimental fast-norm')
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
+group.add_argument('--fast-norm', default=False, action='store_true',
+                   help='enable experimental fast-norm')
+parser.add_argument('--dynamo-backend', default=None, type=str,
+                    help="Select dynamo backend. Default: None")
+
+scripting_group = group.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
+                             help='torch.jit.script the full model')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support.")
+scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                             help="Enable Dynamo optimization.")
 
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
@@ -473,10 +485,16 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
-
-    if args.aot_autograd:
+    elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
+    elif args.dynamo:
+        # FIXME dynamo might need move below DDP wrapping? TBD
+        assert has_dynamo, "torch._dynamo is needed for --dynamo"
+        if args.dynamo_backend is not None:
+            model = torch._dynamo.optimize(args.dynamo_backend)(model)
+        else:
+            model = torch._dynamo.optimize()(model)
 
     if args.lr is None:
         global_batch_size = args.batch_size * args.world_size

@@ -46,6 +46,13 @@ try:
 except ImportError as e:
     has_functorch = False
 
+try:
+    import torch._dynamo
+    has_dynamo = True
+except ImportError:
+    has_dynamo = False
+    pass
+
 _logger = logging.getLogger('validate')
@@ -72,6 +79,8 @@ parser.add_argument('--use-train-size', action='store_true', default=False,
                     help='force use of train input size, even when test size is specified in pretrained cfg')
 parser.add_argument('--crop-pct', default=None, type=float,
                     metavar='N', help='Input image center crop pct')
+parser.add_argument('--crop-mode', default=None, type=str,
+                    metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -112,15 +121,21 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
-scripting_group = parser.add_mutually_exclusive_group()
-scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
-                             help='torch.jit.script the full model')
-scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--dynamo-backend', default=None, type=str,
+                    help="Select dynamo backend. Default: None")
+
+scripting_group = parser.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', default=False, action='store_true',
+                             help='torch.jit.script the full model')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support.")
+scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                             help="Enable Dynamo optimization.")
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@@ -196,21 +211,27 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)
 
-    model = model.to(device)
-    if args.channels_last:
-        model = model.to(memory_format=torch.channels_last)
-
     if args.torchscript:
-        torch.jit.optimized_execution(True)
+        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         model = torch.jit.script(model)
-
-    if args.aot_autograd:
+    elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
+    elif args.dynamo:
+        assert has_dynamo, "torch._dynamo is needed for --dynamo"
+        torch._dynamo.reset()
+        if args.dynamo_backend is not None:
+            model = torch._dynamo.optimize(args.dynamo_backend)(model)
+        else:
+            model = torch._dynamo.optimize()(model)
 
+    model = model.to(device)
     if use_amp == 'apex':
         model = amp.initialize(model, opt_level='O1')
 
+    if args.channels_last:
+        model = model.to(memory_format=torch.channels_last)
+
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
@@ -248,6 +269,7 @@ def validate(args):
         std=data_config['std'],
         num_workers=args.workers,
         crop_pct=crop_pct,
+        crop_mode=data_config['crop_mode'],
         pin_memory=args.pin_mem,
         device=device,
         tf_preprocessing=args.tf_preprocessing,
@@ -376,7 +398,7 @@ def main():
         model_cfgs = [(n, '') for n in model_names]
     elif not is_model(args.model):
         # model name doesn't exist, try as wildcard filter
-        model_names = list_models(args.model)
+        model_names = list_models(args.model, pretrained=True)
         model_cfgs = [(n, '') for n in model_names]
 
     if not model_cfgs and os.path.isfile(args.model):
