Update benchmark script to add precision arg. Fix some downstream (DeiT) compat issues with latest changes. Bump version to 0.4.7

pull/533/head
Ross Wightman 4 years ago
parent ea9c9550b2
commit 288682796f

@ -19,7 +19,7 @@ from contextlib import suppress
from functools import partial
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer
from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
from timm.utils import AverageMeter, setup_default_logging
@ -53,6 +53,10 @@ parser.add_argument('--detail', action='store_true', default=False,
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--num-warm-iter', default=10, type=int,
metavar='N', help='Number of warmup iterations (default: 10)')
parser.add_argument('--num-bench-iter', default=40, type=int,
metavar='N', help='Number of benchmark iterations (default: 40)')
# common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
@ -70,11 +74,9 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL',
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--precision', default='float32', type=str,
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
@ -117,28 +119,50 @@ def cuda_timestamp(sync=False, device=None):
return time.perf_counter()
def count_params(model):
def count_params(model: nn.Module):
return sum([m.numel() for m in model.parameters()])
def resolve_precision(precision: str):
assert precision in ('amp', 'float16', 'bfloat16', 'float32')
use_amp = False
model_dtype = torch.float32
data_dtype = torch.float32
if precision == 'amp':
use_amp = True
elif precision == 'float16':
model_dtype = torch.float16
data_dtype = torch.float16
elif precision == 'bfloat16':
model_dtype = torch.bfloat16
data_dtype = torch.bfloat16
return use_amp, model_dtype, data_dtype
class BenchmarkRunner:
def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs):
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
num_warm_iter=10, num_bench_iter=50, **kwargs):
self.model_name = model_name
self.detail = detail
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
self.model = create_model(
model_name,
num_classes=kwargs.pop('num_classes', None),
in_chans=3,
global_pool=kwargs.pop('gp', 'fast'),
scriptable=torchscript).to(device=self.device)
scriptable=torchscript)
self.model.to(
device=self.device,
dtype=self.model_dtype,
memory_format=torch.channels_last if self.channels_last else None)
self.num_classes = self.model.num_classes
self.param_count = count_params(self.model)
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
self.channels_last = kwargs.pop('channels_last', False)
self.use_amp = kwargs.pop('use_amp', '')
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress
if torchscript:
self.model = torch.jit.script(self.model)
@ -147,16 +171,17 @@ class BenchmarkRunner:
self.batch_size = kwargs.pop('batch_size', 256)
self.example_inputs = None
self.num_warm_iter = 10
self.num_bench_iter = 50
self.log_freq = 10
self.num_warm_iter = num_warm_iter
self.num_bench_iter = num_bench_iter
self.log_freq = num_bench_iter // 5
if 'cuda' in self.device:
self.time_fn = partial(cuda_timestamp, device=self.device)
else:
self.time_fn = timestamp
def _init_input(self):
self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device)
self.example_inputs = torch.randn(
(self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
if self.channels_last:
self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)
@ -166,10 +191,6 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.eval()
if self.use_amp == 'apex':
self.model = amp.initialize(self.model, opt_level='O1')
if self.channels_last:
self.model = self.model.to(memory_format=torch.channels_last)
def run(self):
def _step():
@ -231,16 +252,11 @@ class TrainBenchmarkRunner(BenchmarkRunner):
self.loss = nn.CrossEntropyLoss().to(self.device)
self.target_shape = tuple()
self.optimizer = create_optimizer(
self.optimizer = create_optimizer_v2(
self.model,
opt_name=kwargs.pop('opt', 'sgd'),
lr=kwargs.pop('lr', 1e-4))
if self.use_amp == 'apex':
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')
if self.channels_last:
self.model = self.model.to(memory_format=torch.channels_last)
def _gen_target(self, batch_size):
return torch.empty(
(batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
@ -331,6 +347,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
samples_per_sec=round(num_samples / t_run_elapsed, 2),
step_time=round(1000 * total_step / num_samples, 3),
batch_size=self.batch_size,
img_size=self.input_size[-1],
param_count=round(self.param_count / 1e6, 2),
)
@ -367,23 +384,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
def benchmark(args):
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
else:
_logger.warning("Neither APEX or Native Torch AMP is available.")
if args.native_amp:
args.use_amp = 'native'
_logger.info('Benchmarking in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
args.use_amp = 'apex'
_logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.')
else:
args.use_amp = ''
_logger.info('Benchmarking in float32. AMP not enabled.')
_logger.warning("Overriding precision to 'amp' since --amp flag set.")
args.precision = 'amp'
_logger.info(f'Benchmarking in {args.precision} precision. '
f'{"NHWC" if args.channels_last else "NCHW"} layout. '
f'torchscript {"enabled" if args.torchscript else "disabled"}')
bench_kwargs = vars(args).copy()
bench_kwargs.pop('amp')
model = bench_kwargs.pop('model')
batch_size = bench_kwargs.pop('batch_size')

@ -89,7 +89,7 @@ default_cfgs = dict(
regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
regnety_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'),
regnety_160=_cfg(url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'),
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
)

@ -281,8 +281,9 @@ class VisionTransformer(nn.Module):
# Classifier head(s)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
if num_classes > 0 and distilled else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
# Weight init
assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
@ -336,8 +337,8 @@ class VisionTransformer(nn.Module):
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
if num_classes > 0 and self.dist_token is not None else nn.Identity()
if self.head_dist is not None:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
@ -356,8 +357,8 @@ class VisionTransformer(nn.Module):
def forward(self, x):
x = self.forward_features(x)
if isinstance(x, tuple):
x, x_dist = self.head(x[0]), self.head_dist(x[1])
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
# during inference, return the average of both classifier predictions
return x, x_dist

@ -145,6 +145,12 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# NOTE this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
@ -157,6 +163,12 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
# NOTE this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.

@ -10,4 +10,4 @@ from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .optim_factory import create_optimizer, optimizer_kwargs
from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs

@ -55,7 +55,21 @@ def optimizer_kwargs(cfg):
return kwargs
def create_optimizer(
def create_optimizer(args, model, filter_bias_and_bn=True):
""" Legacy optimizer factory for backwards compatibility.
NOTE: Use create_optimizer_v2 for new code.
"""
opt_args = dict(lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
opt_args['eps'] = args.opt_eps
if hasattr(args, 'opt_betas') and args.opt_betas is not None:
opt_args['betas'] = args.opt_betas
if hasattr(args, 'opt_args') and args.opt_args is not None:
opt_args.update(args.opt_args)
return create_optimizer_v2(model, opt_name=args.opt, filter_bias_and_bn=filter_bias_and_bn, **opt_args)
def create_optimizer_v2(
model: nn.Module,
opt_name: str = 'sgd',
lr: Optional[float] = None,

@ -1 +1 @@
__version__ = '0.4.6'
__version__ = '0.4.7'

@ -33,7 +33,7 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_c
convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer, optimizer_kwargs
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
@ -389,7 +389,7 @@ def main():
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
optimizer = create_optimizer(model, **optimizer_kwargs(cfg=args))
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing

Loading…
Cancel
Save