diff --git a/benchmark.py b/benchmark.py
index e692eacc..5f296c24 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -19,7 +19,7 @@ from contextlib import suppress
 from functools import partial
 
 from timm.models import create_model, is_model, list_models
-from timm.optim import create_optimizer
+from timm.optim import create_optimizer_v2
 from timm.data import resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging
 
@@ -53,6 +53,10 @@ parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
+parser.add_argument('--num-warm-iter', default=10, type=int,
+                    metavar='N', help='Number of warmup iterations (default: 10)')
+parser.add_argument('--num-bench-iter', default=40, type=int,
+                    metavar='N', help='Number of benchmark iterations (default: 40)')
 
 # common inference / train args
 parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
@@ -70,11 +74,9 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL',
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
 parser.add_argument('--amp', action='store_true', default=False,
-                    help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
-parser.add_argument('--apex-amp', action='store_true', default=False,
-                    help='Use NVIDIA Apex AMP mixed precision')
-parser.add_argument('--native-amp', action='store_true', default=False,
-                    help='Use Native Torch AMP mixed precision')
+                    help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
+parser.add_argument('--precision', default='float32', type=str,
+                    help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
@@ -117,28 +119,50 @@ def cuda_timestamp(sync=False, device=None):
     return time.perf_counter()
 
 
-def count_params(model):
+def count_params(model: nn.Module):
     return sum([m.numel() for m in model.parameters()])
 
 
+def resolve_precision(precision: str):
+    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
+    use_amp = False
+    model_dtype = torch.float32
+    data_dtype = torch.float32
+    if precision == 'amp':
+        use_amp = True
+    elif precision == 'float16':
+        model_dtype = torch.float16
+        data_dtype = torch.float16
+    elif precision == 'bfloat16':
+        model_dtype = torch.bfloat16
+        data_dtype = torch.bfloat16
+    return use_amp, model_dtype, data_dtype
+
+
 class BenchmarkRunner:
-    def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
+            num_warm_iter=10, num_bench_iter=50, **kwargs):
         self.model_name = model_name
         self.detail = detail
         self.device = device
+        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
+        self.channels_last = kwargs.pop('channels_last', False)
+        self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
+
         self.model = create_model(
             model_name,
             num_classes=kwargs.pop('num_classes', None),
             in_chans=3,
             global_pool=kwargs.pop('gp', 'fast'),
-            scriptable=torchscript).to(device=self.device)
+            scriptable=torchscript)
+        self.model.to(
+            device=self.device,
+            dtype=self.model_dtype,
+            memory_format=torch.channels_last if self.channels_last else None)
         self.num_classes = self.model.num_classes
         self.param_count = count_params(self.model)
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
-
-        self.channels_last = kwargs.pop('channels_last', False)
-        self.use_amp = kwargs.pop('use_amp', '')
-        self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress
 
         if torchscript:
             self.model = torch.jit.script(self.model)
@@ -147,16 +171,17 @@ class BenchmarkRunner:
         self.batch_size = kwargs.pop('batch_size', 256)
 
         self.example_inputs = None
-        self.num_warm_iter = 10
-        self.num_bench_iter = 50
-        self.log_freq = 10
+        self.num_warm_iter = num_warm_iter
+        self.num_bench_iter = num_bench_iter
+        self.log_freq = num_bench_iter // 5
         if 'cuda' in self.device:
             self.time_fn = partial(cuda_timestamp, device=self.device)
         else:
             self.time_fn = timestamp
 
     def _init_input(self):
-        self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device)
+        self.example_inputs = torch.randn(
+            (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
         if self.channels_last:
             self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)
 
@@ -166,10 +191,6 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
 
     def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.eval()
-        if self.use_amp == 'apex':
-            self.model = amp.initialize(self.model, opt_level='O1')
-        if self.channels_last:
-            self.model = self.model.to(memory_format=torch.channels_last)
 
     def run(self):
         def _step():
@@ -231,16 +252,11 @@ class TrainBenchmarkRunner(BenchmarkRunner):
         self.loss = nn.CrossEntropyLoss().to(self.device)
         self.target_shape = tuple()
 
-        self.optimizer = create_optimizer(
+        self.optimizer = create_optimizer_v2(
             self.model,
             opt_name=kwargs.pop('opt', 'sgd'),
             lr=kwargs.pop('lr', 1e-4))
 
-        if self.use_amp == 'apex':
-            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')
-        if self.channels_last:
-            self.model = self.model.to(memory_format=torch.channels_last)
-
     def _gen_target(self, batch_size):
         return torch.empty(
             (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
@@ -331,6 +347,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             samples_per_sec=round(num_samples / t_run_elapsed, 2),
             step_time=round(1000 * total_step / num_samples, 3),
             batch_size=self.batch_size,
+            img_size=self.input_size[-1],
             param_count=round(self.param_count / 1e6, 2),
         )
 
@@ -367,23 +384,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
 
 
 def benchmark(args):
     if args.amp:
-        if has_native_amp:
-            args.native_amp = True
-        elif has_apex:
-            args.apex_amp = True
-        else:
-            _logger.warning("Neither APEX or Native Torch AMP is available.")
-    if args.native_amp:
-        args.use_amp = 'native'
-        _logger.info('Benchmarking in mixed precision with native PyTorch AMP.')
-    elif args.apex_amp:
-        args.use_amp = 'apex'
-        _logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.')
-    else:
-        args.use_amp = ''
-        _logger.info('Benchmarking in float32. AMP not enabled.')
+        _logger.warning("Overriding precision to 'amp' since --amp flag set.")
+        args.precision = 'amp'
+    _logger.info(f'Benchmarking in {args.precision} precision. '
+                 f'{"NHWC" if args.channels_last else "NCHW"} layout. '
+                 f'torchscript {"enabled" if args.torchscript else "disabled"}')
 
     bench_kwargs = vars(args).copy()
+    bench_kwargs.pop('amp')
     model = bench_kwargs.pop('model')
     batch_size = bench_kwargs.pop('batch_size')
diff --git a/timm/models/regnet.py b/timm/models/regnet.py
index 40988946..26d8650b 100644
--- a/timm/models/regnet.py
+++ b/timm/models/regnet.py
@@ -89,7 +89,7 @@ default_cfgs = dict(
     regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
     regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
     regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
-    regnety_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'),
+    regnety_160=_cfg(url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'),
     regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
 )
 
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 578a5f08..7a7afbff 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -281,8 +281,9 @@ class VisionTransformer(nn.Module):
 
         # Classifier head(s)
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
-            if num_classes > 0 and distilled else nn.Identity()
+        self.head_dist = None
+        if distilled:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
         # Weight init
         assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
@@ -336,8 +337,8 @@ class VisionTransformer(nn.Module):
     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
-            if num_classes > 0 and self.dist_token is not None else nn.Identity()
+        if self.head_dist is not None:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.patch_embed(x)
@@ -356,8 +357,8 @@ class VisionTransformer(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        if isinstance(x, tuple):
-            x, x_dist = self.head(x[0]), self.head_dist(x[1])
+        if self.head_dist is not None:
+            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
             if self.training and not torch.jit.is_scripting():
                 # during inference, return the average of both classifier predictions
                 return x, x_dist
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index 816bbc8e..1656559f 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -145,6 +145,12 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+    # NOTE this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_base_r50_s16_384(pretrained=False, **kwargs):
     """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
@@ -157,6 +163,12 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_resnet50_384(pretrained=False, **kwargs):
+    # NOTE this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
     """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py
index 8bb21abb..7c4f4d36 100644
--- a/timm/optim/__init__.py
+++ b/timm/optim/__init__.py
@@ -10,4 +10,4 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 
-from .optim_factory import create_optimizer, optimizer_kwargs
\ No newline at end of file
+from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
\ No newline at end of file
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index c3abdb76..a4844f14 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -55,7 +55,21 @@ def optimizer_kwargs(cfg):
     return kwargs
 
 
-def create_optimizer(
+def create_optimizer(args, model, filter_bias_and_bn=True):
+    """ Legacy optimizer factory for backwards compatibility.
+    NOTE: Use create_optimizer_v2 for new code.
+    """
+    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
+    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
+        opt_args['eps'] = args.opt_eps
+    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
+        opt_args['betas'] = args.opt_betas
+    if hasattr(args, 'opt_args') and args.opt_args is not None:
+        opt_args.update(args.opt_args)
+    return create_optimizer_v2(model, opt_name=args.opt, filter_bias_and_bn=filter_bias_and_bn, **opt_args)
+
+
+def create_optimizer_v2(
     model: nn.Module,
     opt_name: str = 'sgd',
     lr: Optional[float] = None,
diff --git a/timm/version.py b/timm/version.py
index ab45471d..1e4826d6 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '0.4.6'
+__version__ = '0.4.7'
diff --git a/train.py b/train.py
index 7b8e92e8..e1f308ae 100755
--- a/train.py
+++ b/train.py
@@ -33,7 +33,7 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_c
     convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
-from timm.optim import create_optimizer, optimizer_kwargs
+from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
 
@@ -389,7 +389,7 @@ def main():
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
 
-    optimizer = create_optimizer(model, **optimizer_kwargs(cfg=args))
+    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
 
     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
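Usage note (illustrative sketch, not part of the patch): the optim_factory change keeps `create_optimizer(args, model)` as a thin backwards-compatibility wrapper and adds `create_optimizer_v2(model, opt_name=..., lr=..., ...)` for direct keyword-argument use. The toy model and hyperparameter values below are assumptions for demonstration only; the factory calls follow the signatures shown in the diff.

```python
# Illustrative sketch only -- not part of the diff above. The toy model and
# hyperparameter values are made up; the factory calls match the signatures
# introduced in timm/optim/optim_factory.py by this patch (timm 0.4.7).
from types import SimpleNamespace

import torch.nn as nn
from timm.optim import create_optimizer, create_optimizer_v2

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))

# New-style factory: model first, hyperparameters as plain keyword arguments.
opt = create_optimizer_v2(model, opt_name='sgd', lr=0.01, momentum=0.9, weight_decay=1e-4)

# Legacy wrapper kept for backwards compatibility: pulls the same values off
# an argparse-style namespace and forwards them to create_optimizer_v2.
args = SimpleNamespace(opt='sgd', lr=0.01, momentum=0.9, weight_decay=1e-4)
opt_legacy = create_optimizer(args, model)
```

On the benchmark.py side the new behavior is driven from the CLI: `--precision float16` or `--precision bfloat16` casts both the model and the example inputs via `resolve_precision`, while `--precision amp` (or the `--amp` shortcut, which overrides `--precision`) runs under `torch.cuda.amp.autocast`.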