From 4bc103f504e0d4f36d36c728aedb32966a7806cc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Feb 2021 13:15:52 -0800 Subject: [PATCH 01/24] Fix CUDA crash w/ channels-last + CSP models. Remove use of chunk() --- timm/models/cspnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index ca9eaf16..afd1dcd7 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -264,9 +264,11 @@ class CrossStage(nn.Module): if self.conv_down is not None: x = self.conv_down(x) x = self.conv_exp(x) - xs, xb = x.chunk(2, dim=1) + split = x.shape[1] // 2 + xs, xb = x[:, :split], x[:, split:] xb = self.blocks(xb) - out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1)) + xb = self.conv_transition_b(xb).contiguous() + out = self.conv_transition(torch.cat([xs, xb], dim=1)) return out From 0e16d4e9fb15a7a51cfe2d01d69c7a23dce713c3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Feb 2021 13:26:42 -0800 Subject: [PATCH 02/24] Add benchmark.py script, and update optimizer factory to be more friendly to use outside of argparse interface. --- benchmark.py | 459 ++++++++++++++++++++++++++++++++++++ timm/optim/__init__.py | 2 +- timm/optim/optim_factory.py | 79 +++++-- train.py | 4 +- 4 files changed, 519 insertions(+), 25 deletions(-) create mode 100755 benchmark.py diff --git a/benchmark.py b/benchmark.py new file mode 100755 index 00000000..c745e278 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +""" Model Benchmark Script + +An inference and train step benchmark script for timm models. + +Hacked together by Ross Wightman (https://github.com/rwightman) +""" +import argparse +import os +import csv +import json +import time +import logging +import torch +import torch.nn as nn +import torch.nn.parallel +from collections import OrderedDict +from contextlib import suppress +from functools import partial + +from timm.models import create_model, is_model, list_models +from timm.optim import create_optimizer +from timm.data import resolve_data_config +from timm.utils import AverageMeter, setup_default_logging + + +has_apex = False +try: + from apex import amp + has_apex = True +except ImportError: + pass + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +torch.backends.cudnn.benchmark = True +_logger = logging.getLogger('validate') + + +parser = argparse.ArgumentParser(description='PyTorch Benchmark') + +# benchmark specific args +parser.add_argument('--bench', default='both', type=str, + help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'") +parser.add_argument('--detail', action='store_true', default=False, + help='Provide train fwd/bwd/opt breakdown detail if True. 
Defaults to False') +parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', + help='Output csv file for validation results (summary)') + +# common inference / train args +parser.add_argument('--model', '-m', metavar='NAME', default='resnet50', + help='model architecture (default: resnet50)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--input-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +parser.add_argument('--num-classes', type=int, default=None, + help='Number classes in dataset') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--amp', action='store_true', default=False, + help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') + + +# train optimizer parameters +parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "sgd"') +parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: None, use opt default)') +parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') +parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='Optimizer momentum (default: 0.9)') +parser.add_argument('--weight-decay', type=float, default=0.0001, + help='weight decay (default: 0.0001)') +parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') +parser.add_argument('--clip-mode', type=str, default='norm', + help='Gradient clipping mode. 
One of ("norm", "value", "agc")') + + +# model regularization / loss params that impact model or loss fn +parser.add_argument('--smoothing', type=float, default=0.1, + help='Label smoothing (default: 0.1)') +parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', + help='Dropout rate (default: 0.)') +parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', + help='Drop path rate (default: None)') +parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', + help='Drop block rate (default: None)') + + +def timestamp(sync=False): + return time.perf_counter() + + +def cuda_timestamp(sync=False, device=None): + if sync: + torch.cuda.synchronize(device=device) + return time.perf_counter() + + +def count_params(model): + return sum([m.numel() for m in model.parameters()]) + + +class BenchmarkRunner: + def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs): + self.model_name = model_name + self.detail = detail + self.device = device + self.model = create_model( + model_name, + num_classes=kwargs.pop('num_classes', None), + in_chans=3, + global_pool=kwargs.pop('gp', 'fast'), + scriptable=torchscript).to(device=self.device) + self.num_classes = self.model.num_classes + self.param_count = count_params(self.model) + _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) + + self.channels_last = kwargs.pop('channels_last', False) + self.use_amp = kwargs.pop('use_amp', '') + self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress + if torchscript: + self.model = torch.jit.script(self.model) + + data_config = resolve_data_config(kwargs, model=self.model, use_test_size=True) + self.input_size = data_config['input_size'] + self.batch_size = kwargs.pop('batch_size', 256) + + self.example_inputs = None + self.num_warm_iter = 10 + self.num_bench_iter = 50 + self.log_freq = 10 + if 'cuda' in self.device: + self.time_fn = partial(cuda_timestamp, device=self.device) + else: + self.time_fn = timestamp + + def _init_input(self): + self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device) + if self.channels_last: + self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last) + + +class InferenceBenchmarkRunner(BenchmarkRunner): + + def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) + self.model.eval() + if self.use_amp == 'apex': + self.model = amp.initialize(self.model, opt_level='O1') + if self.channels_last: + self.model = self.model.to(memory_format=torch.channels_last) + + def run(self): + def _step(): + t_step_start = self.time_fn() + with self.amp_autocast(): + output = self.model(self.example_inputs) + t_step_end = self.time_fn(True) + return t_step_end - t_step_start + + _logger.info( + f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ ' + f'input size {self.input_size} and batch size {self.batch_size}.') + + with torch.no_grad(): + self._init_input() + + for _ in range(self.num_warm_iter): + _step() + + total_step = 0. + num_samples = 0 + t_run_start = self.time_fn() + for i in range(self.num_bench_iter): + delta_fwd = _step() + total_step += delta_fwd + num_samples += self.batch_size + if (i + 1) % self.log_freq == 0: + _logger.info( + f"Infer [{i + 1}/{self.num_bench_iter}]." + f" {num_samples / total_step:0.2f} samples/sec." 
+ f" {1000 * total_step / num_samples:0.3f} ms/sample.") + t_run_end = self.time_fn(True) + t_run_elapsed = t_run_end - t_run_start + + results = dict( + samples_per_sec=round(num_samples / t_run_elapsed, 2), + step_time=round(1000 * total_step / num_samples, 3), + batch_size=self.batch_size, + param_count=round(self.param_count / 1e6, 2), + ) + + _logger.info( + f"Inference benchmark of {self.model_name} done. " + f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample") + + return results + + +class TrainBenchmarkRunner(BenchmarkRunner): + + def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): + super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) + self.model.train() + + if kwargs.pop('smoothing', 0) > 0: + self.loss = nn.CrossEntropyLoss().to(self.device) + else: + self.loss = nn.CrossEntropyLoss().to(self.device) + self.target_shape = tuple() + + self.optimizer = create_optimizer( + self.model, + opt_name=kwargs.pop('opt', 'sgd'), + lr=kwargs.pop('lr', 1e-4)) + + if self.use_amp == 'apex': + self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1') + if self.channels_last: + self.model = self.model.to(memory_format=torch.channels_last) + + def _gen_target(self, batch_size): + return torch.empty( + (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes) + + def run(self): + def _step(detail=False): + self.optimizer.zero_grad() # can this be ignored? + t_start = self.time_fn() + t_fwd_end = t_start + t_bwd_end = t_start + with self.amp_autocast(): + output = self.model(self.example_inputs) + if isinstance(output, tuple): + output = output[0] + if detail: + t_fwd_end = self.time_fn(True) + target = self._gen_target(output.shape[0]) + self.loss(output, target).backward() + if detail: + t_bwd_end = self.time_fn(True) + self.optimizer.step() + t_end = self.time_fn(True) + if detail: + delta_fwd = t_fwd_end - t_start + delta_bwd = t_bwd_end - t_fwd_end + delta_opt = t_end - t_bwd_end + return delta_fwd, delta_bwd, delta_opt + else: + delta_step = t_end - t_start + return delta_step + + _logger.info( + f'Running train benchmark on {self.model_name} for {self.num_bench_iter} steps w/ ' + f'input size {self.input_size} and batch size {self.batch_size}.') + + self._init_input() + + for _ in range(self.num_warm_iter): + _step() + + t_run_start = self.time_fn() + if self.detail: + total_fwd = 0. + total_bwd = 0. + total_opt = 0. + num_samples = 0 + for i in range(self.num_bench_iter): + delta_fwd, delta_bwd, delta_opt = _step(True) + num_samples += self.batch_size + total_fwd += delta_fwd + total_bwd += delta_bwd + total_opt += delta_opt + if (i + 1) % self.log_freq == 0: + total_step = total_fwd + total_bwd + total_opt + _logger.info( + f"Train [{i + 1}/{self.num_bench_iter}]." + f" {num_samples / total_step:0.2f} samples/sec." + f" {1000 * total_fwd / num_samples:0.3f} ms/sample fwd," + f" {1000 * total_bwd / num_samples:0.3f} ms/sample bwd," + f" {1000 * total_opt / num_samples:0.3f} ms/sample opt." 
+ ) + total_step = total_fwd + total_bwd + total_opt + t_run_elapsed = self.time_fn() - t_run_start + results = dict( + samples_per_sec=round(num_samples / t_run_elapsed, 2), + step_time=round(1000 * total_step / num_samples, 3), + fwd_time=round(1000 * total_fwd / num_samples, 3), + bwd_time=round(1000 * total_bwd / num_samples, 3), + opt_time=round(1000 * total_opt / num_samples, 3), + batch_size=self.batch_size, + param_count=round(self.param_count / 1e6, 2), + ) + else: + total_step = 0. + num_samples = 0 + for i in range(self.num_bench_iter): + delta_step = _step(False) + num_samples += self.batch_size + total_step += delta_step + if (i + 1) % self.log_freq == 0: + _logger.info( + f"Train [{i + 1}/{self.num_bench_iter}]." + f" {num_samples / total_step:0.2f} samples/sec." + f" {1000 * total_step / num_samples:0.3f} ms/sample.") + t_run_elapsed = self.time_fn() - t_run_start + results = dict( + samples_per_sec=round(num_samples / t_run_elapsed, 2), + step_time=round(1000 * total_step / num_samples, 3), + batch_size=self.batch_size, + param_count=round(self.param_count / 1e6, 2), + ) + + _logger.info( + f"Train benchmark of {self.model_name} done. " + f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample") + + return results + + +def decay_batch_exp(batch_size, factor=0.5, divisor=16): + out_batch_size = batch_size * factor + if out_batch_size > divisor: + out_batch_size = (out_batch_size + 1) // divisor * divisor + else: + out_batch_size = batch_size - 1 + return max(0, int(out_batch_size)) + + +def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): + batch_size = initial_batch_size + results = dict() + while batch_size >= 1: + try: + bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs) + results = bench.run() + return results + except RuntimeError as e: + torch.cuda.empty_cache() + batch_size = decay_batch_exp(batch_size) + print(f'Reducing batch size to {batch_size}') + return results + + +def benchmark(args): + if args.amp: + if has_native_amp: + args.native_amp = True + elif has_apex: + args.apex_amp = True + else: + _logger.warning("Neither APEX or Native Torch AMP is available.") + if args.native_amp: + args.use_amp = 'native' + _logger.info('Benchmarking in mixed precision with native PyTorch AMP.') + elif args.apex_amp: + args.use_amp = 'apex' + _logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.') + else: + args.use_amp = '' + _logger.info('Benchmarking in float32. 
AMP not enabled.') + + bench_kwargs = vars(args).copy() + model = bench_kwargs.pop('model') + batch_size = bench_kwargs.pop('batch_size') + + bench_fns = (InferenceBenchmarkRunner,) + prefixes = ('infer',) + if args.bench == 'both': + bench_fns = ( + InferenceBenchmarkRunner, + TrainBenchmarkRunner + ) + prefixes = ('infer', 'train') + elif args.bench == 'train': + bench_fns = TrainBenchmarkRunner, + prefixes = 'train', + + model_results = OrderedDict(model=model) + for prefix, bench_fn in zip(prefixes, bench_fns): + run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs) + if prefix: + run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} + model_results.update(run_results) + param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0)) + model_results.setdefault('param_count', param_count) + model_results.pop('train_param_count', 0) + return model_results + + +def main(): + setup_default_logging() + args = parser.parse_args() + model_cfgs = [] + model_names = [] + + if args.model == 'all': + # validate all models in a list of names with pretrained checkpoints + args.pretrained = True + model_names = list_models(pretrained=True, exclude_filters=['*in21k']) + model_cfgs = [(n, None) for n in model_names] + elif not is_model(args.model): + # model name doesn't exist, try as wildcard filter + model_names = list_models(args.model) + model_cfgs = [(n, None) for n in model_names] + + if len(model_cfgs): + results_file = args.results_file or './benchmark.csv' + _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) + results = [] + try: + for m, _ in model_cfgs: + args.model = m + r = benchmark(args) + results.append(r) + except KeyboardInterrupt as e: + pass + sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec' + results = sorted(results, key=lambda x: x[sort_key], reverse=True) + if len(results): + write_results(results_file, results) + + import json + json_str = json.dumps(results, indent=4) + print(json_str) + else: + benchmark(args) + + +def write_results(results_file, results): + with open(results_file, mode='w') as cf: + dw = csv.DictWriter(cf, fieldnames=results[0].keys()) + dw.writeheader() + for r in results: + dw.writerow(r) + cf.flush() + + +if __name__ == '__main__': + main() diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 33e4907f..8bb21abb 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -10,4 +10,4 @@ from .radam import RAdam from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .optim_factory import create_optimizer \ No newline at end of file +from .optim_factory import create_optimizer, optimizer_kwargs \ No newline at end of file diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 4c0aaca0..c3abdb76 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -1,8 +1,11 @@ """ Optimizer Factory w/ Custom Weight Decay Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Optional + import torch -from torch import optim as optim +import torch.nn as nn +import torch.optim as optim from .adafactor import Adafactor from .adahessian import Adahessian @@ -37,9 +40,49 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()): {'params': decay, 'weight_decay': weight_decay}] -def create_optimizer(args, model, filter_bias_and_bn=True): - opt_lower = args.opt.lower() - weight_decay = 
args.weight_decay +def optimizer_kwargs(cfg): + """ cfg/argparse to kwargs helper + Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. + """ + kwargs = dict(opt_name=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay) + if getattr(cfg, 'opt_eps', None) is not None: + kwargs['eps'] = cfg.opt_eps + if getattr(cfg, 'opt_betas', None) is not None: + kwargs['betas'] = cfg.opt_betas + if getattr(cfg, 'opt_args', None) is not None: + kwargs.update(cfg.opt_args) + kwargs['momentum'] = cfg.momentum + return kwargs + + +def create_optimizer( + model: nn.Module, + opt_name: str = 'sgd', + lr: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + filter_bias_and_bn: bool = True, + **kwargs): + """ Create an optimizer. + + TODO currently the model is passed in and all parameters are selected for optimization. + For more general use an interface that allows selection of parameters to optimize and lr groups, one of: + * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion + * expose the parameters interface and leave it up to caller + + Args: + model (nn.Module): model containing parameters to optimize + opt_name: name of optimizer to create + lr: initial learning rate + weight_decay: weight decay to apply in optimizer + momentum: momentum for momentum based optimizers (others may use betas via kwargs) + filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay + **kwargs: extra optimizer specific kwargs to pass through + + Returns: + Optimizer + """ + opt_lower = opt_name.lower() if weight_decay and filter_bias_and_bn: skip = {} if hasattr(model, 'no_weight_decay'): @@ -48,26 +91,18 @@ def create_optimizer(args, model, filter_bias_and_bn=True): weight_decay = 0. 
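# Illustrative usage sketch (assumed, based on the new signatures introduced in this
# commit): create_optimizer() can now be called directly, without an argparse
# Namespace, and optimizer_kwargs() bridges argparse/cfg-style objects to the new kwargs.
import torch.nn as nn
from timm.optim import create_optimizer, optimizer_kwargs

net = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.Linear(32, 10))

# direct call, no argparse involved; bias/norm params get weight_decay=0 by default
opt = create_optimizer(net, opt_name='adamw', lr=5e-4, weight_decay=0.05)

class Cfg:  # stand-in for argparse args; attribute names match optimizer_kwargs()
    opt, lr, weight_decay, momentum = 'sgd', 0.1, 1e-4, 0.9
    opt_eps = opt_betas = opt_args = None

opt_from_cfg = create_optimizer(net, **optimizer_kwargs(cfg=Cfg()))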
else: parameters = model.parameters() - if 'fused' in opt_lower: assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' - opt_args = dict(lr=args.lr, weight_decay=weight_decay) - if hasattr(args, 'opt_eps') and args.opt_eps is not None: - opt_args['eps'] = args.opt_eps - if hasattr(args, 'opt_betas') and args.opt_betas is not None: - opt_args['betas'] = args.opt_betas - if hasattr(args, 'opt_args') and args.opt_args is not None: - opt_args.update(args.opt_args) - + opt_args = dict(lr=lr, weight_decay=weight_decay, **kwargs) opt_split = opt_lower.split('_') opt_lower = opt_split[-1] if opt_lower == 'sgd' or opt_lower == 'nesterov': opt_args.pop('eps', None) - optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'momentum': opt_args.pop('eps', None) - optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) elif opt_lower == 'adam': optimizer = optim.Adam(parameters, **opt_args) elif opt_lower == 'adamw': @@ -78,30 +113,30 @@ def create_optimizer(args, model, filter_bias_and_bn=True): optimizer = RAdam(parameters, **opt_args) elif opt_lower == 'adamp': optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) - elif opt_lower == 'sgdp': - optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) + elif opt_lower == 'sgdp': + optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'adadelta': optimizer = optim.Adadelta(parameters, **opt_args) elif opt_lower == 'adafactor': - if not args.lr: + if not lr: opt_args['lr'] = None optimizer = Adafactor(parameters, **opt_args) elif opt_lower == 'adahessian': optimizer = Adahessian(parameters, **opt_args) elif opt_lower == 'rmsprop': - optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) + optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) elif opt_lower == 'rmsproptf': - optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) + optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) elif opt_lower == 'novograd': optimizer = NovoGrad(parameters, **opt_args) elif opt_lower == 'nvnovograd': optimizer = NvNovoGrad(parameters, **opt_args) elif opt_lower == 'fusedsgd': opt_args.pop('eps', None) - optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args) elif opt_lower == 'fusedmomentum': opt_args.pop('eps', None) - optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args) elif opt_lower == 'fusedadam': optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) elif opt_lower == 'fusedadamw': diff --git a/train.py b/train.py index 9abcfed3..9db5175b 100755 --- a/train.py +++ b/train.py @@ -32,7 +32,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy -from timm.optim import create_optimizer +from timm.optim import 
create_optimizer, optimizer_kwargs from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler @@ -384,7 +384,7 @@ def main(): assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) - optimizer = create_optimizer(args, model) + optimizer = create_optimizer(model, **optimizer_kwargs(cfg=args)) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing From f0ffdf89b308b192be6390657110e44cf9cecf26 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Feb 2021 15:54:55 -0800 Subject: [PATCH 03/24] Add numerous experimental ViT Hybrid models w/ ResNetV2 base. Update the ViT naming for hybrids. Fix #426 for pretrained vit resizing. --- timm/models/resnetv2.py | 6 +- timm/models/vision_transformer.py | 276 ++++++++++++++++++++++++++---- 2 files changed, 247 insertions(+), 35 deletions(-) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 73c2e42c..2df02f49 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -274,7 +274,9 @@ class ResNetStage(nn.Module): return x -def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None): +def create_resnetv2_stem( + in_chs, out_chs=64, stem_type='', preact=True, + conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): stem = OrderedDict() assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') @@ -322,7 +324,7 @@ class ResNetV2(nn.Module): self.feature_info = [] stem_chs = make_div(stem_chs * wf) - self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) + self.stem = create_resnetv2_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) # NOTE no, reduction 2 feature if preact self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm')) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index acd4d18d..02c32cb7 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -28,9 +28,9 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained -from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_ +from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_ from .resnet import resnet26d, resnet50d -from .resnetv2 import ResNetV2 +from .resnetv2 import ResNetV2, create_resnetv2_stem from .registry import register_model _logger = logging.getLogger(__name__) @@ -97,17 +97,62 @@ default_cfgs = { url='', # FIXME I have weights for this but > 2GB limit for github release binaries num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - # hybrid models (weights ported from official Google JAX impl) - 'vit_base_resnet50_224_in21k': _cfg( + # hybrid in-21k models (weights ported from official Google JAX impl where they exist) + 'vit_base_r50_s16_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'), - 'vit_base_resnet50_384': _cfg( + + # hybrid in-1k models (weights ported from official Google JAX impl where they exist) + 'vit_small_r_s16_p8_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 
0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_p2_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_p2_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r26_s32_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r26_s32_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r20_s16_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r20_s16_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r26_s32_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r26_s32_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r50_s16_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r50_s16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_large_r50_s32_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_large_r50_s32_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), # hybrid models (my experiments) 'vit_small_resnet26d_224': _cfg(), - 'vit_small_resnet50d_s3_224': _cfg(), + 'vit_small_resnet50d_s16_224': _cfg(), 'vit_base_resnet26d_224': _cfg(), 'vit_base_resnet50d_224': _cfg(), @@ -227,11 +272,13 @@ class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. 
""" - def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) self.img_size = img_size + self.patch_size = patch_size self.backbone = backbone if feature_size is None: with torch.no_grad(): @@ -253,8 +300,9 @@ class HybridEmbed(nn.Module): feature_dim = self.backbone.feature_info.channels()[-1] else: feature_dim = self.backbone.num_features - self.num_patches = feature_size[0] * feature_size[1] - self.proj = nn.Conv2d(feature_dim, embed_dim, 1) + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.backbone(x) @@ -270,9 +318,10 @@ class VisionTransformer(nn.Module): A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + act_layer=None): """ Args: img_size (int, tuple): input image size @@ -296,10 +345,12 @@ class VisionTransformer(nn.Module): self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + patch_size = patch_size or 1 if hybrid_backbone is not None else 16 if hybrid_backbone is not None: self.patch_embed = HybridEmbed( - hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) @@ -313,7 +364,7 @@ class VisionTransformer(nn.Module): self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) @@ -423,13 +474,15 @@ class DistilledVisionTransformer(VisionTransformer): return (x + x_dist) / 2 -def resize_pos_embed(posemb, posemb_new): +def resize_pos_embed(posemb, posemb_new, token='class'): # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if True: - posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + if token: + assert token in ('class', 'distill') + token_idx = 2 if token == 'distill' else 1 + posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:] ntok_new -= 1 else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] @@ -633,33 +686,190 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): return model +def _resnetv2(layers=(3, 4, 9), **kwargs): + """ ResNet-V2 backbone helper""" + padding_same = kwargs.get('padding_same', True) + if padding_same: + stem_type = 'same' + conv_layer = StdConv2dSame + else: + stem_type = '' + conv_layer = StdConv2d + if len(layers): + backbone = ResNetV2( + layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), + preact=False, stem_type=stem_type, conv_layer=conv_layer) + else: + backbone = create_resnetv2_stem( + kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer) + return backbone + + @register_model -def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): +def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ - # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head - backbone = ResNetV2( - layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), - preact=False, stem_type='same', conv_layer=StdConv2dSame) + backbone = _resnetv2(layers=(3, 4, 9), **kwargs) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r20_s16_p2_224(pretrained=False, **kwargs): + """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224. 
+ """ + backbone = _resnetv2((2, 4), **kwargs) model_kwargs = dict( - embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, - representation_size=768, **kwargs) - model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs) + patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r20_s16_p2_384(pretrained=False, **kwargs): + """ R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384. + """ + backbone = _resnetv2((2, 4), **kwargs) + model_kwargs = dict( + embed_dim=384, patch_size=2, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r20_s16_224(pretrained=False, **kwargs): + """ R20+ViT-S/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_resnet50_384(pretrained=False, **kwargs): +def vit_small_r20_s16_384(pretrained=False, **kwargs): + """ R20+ViT-S/S16 hybrid @ 384x384. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_384(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid @ 384x384. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r20_s16_224(pretrained=False, **kwargs): + """ R20+ViT-B/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs) + model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r20_s16_384(pretrained=False, **kwargs): + """ R20+ViT-B/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-B/S32 hybrid. 
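# Illustrative sketch (not from the patch itself) of how these hybrid variants
# compose: an R26 ResNetV2 backbone feeds HybridEmbed, whose default 1x1 "patches"
# over the stride-32 feature map yield 7*7 = 49 tokens at 224x224. No pretrained
# weights exist for this variant here, so the values below are expectations only.
import torch
import timm

m = timm.create_model('vit_base_r26_s32_224')
print(m.patch_embed.num_patches)              # expected: 49
print(m(torch.randn(1, 3, 224, 224)).shape)   # expected: torch.Size([1, 1000])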
+ """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224(pretrained=False, **kwargs): + """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_384(pretrained=False, **kwargs): """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head - backbone = ResNetV2( - layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), - preact=False, stem_type='same', conv_layer=StdConv2dSame) + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_r50_s32_224(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_r50_s32_384(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs) return model @@ -674,12 +884,12 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs): @register_model -def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): +def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. """ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs) return model From de97be9146435c896a62e017700a124183e9db0c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Feb 2021 16:22:55 -0800 Subject: [PATCH 04/24] Spell out diff between my small and deit small vit models. 
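For reference, a quick sketch of the difference this docstring spells out (illustrative only; model names as registered in this tree, no weights downloaded):

import timm

my_small = timm.create_model('vit_small_patch16_224')         # embed_dim=768, depth=8, heads=8, mlp_ratio=3, no qkv bias
deit_small = timm.create_model('vit_deit_small_patch16_224')  # embed_dim=384, depth=12, heads=6

for name, m in [('vit_small_patch16_224', my_small), ('vit_deit_small_patch16_224', deit_small)]:
    print(name, f'{sum(p.numel() for p in m.parameters()) / 1e6:.1f}M params, embed_dim={m.embed_dim}')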
--- timm/models/vision_transformer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 02c32cb7..f6a09ac2 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -541,7 +541,11 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwa @register_model def vit_small_patch16_224(pretrained=False, **kwargs): - """ My custom 'small' ViT model. Depth=8, heads=8= mlp_ratio=3.""" + """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3. + NOTE: + * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6 + * this model does not have a bias for QKV (unlike the official ViT and DeiT models) + """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) @@ -994,4 +998,4 @@ def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs): model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer( 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) - return model \ No newline at end of file + return model From 2db2d87ff7e78f8326d37cace94315167d409c29 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Feb 2021 17:31:42 -0800 Subject: [PATCH 05/24] Add epoch-repeats arg to multiply the number of dataset passes per epoch. Currently for iterable datasets (read TFDS wrapper) only. --- timm/data/config.py | 2 +- timm/data/dataset.py | 3 ++- timm/data/dataset_factory.py | 1 + timm/data/parsers/parser_tfds.py | 9 +++++---- train.py | 6 +++++- validate.py | 2 +- 6 files changed, 15 insertions(+), 8 deletions(-) diff --git a/timm/data/config.py b/timm/data/config.py index dad8eb13..38f5689a 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -5,7 +5,7 @@ from .constants import * _logger = logging.getLogger(__name__) -def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True): +def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False): new_config = {} default_cfg = default_cfg if not default_cfg and model is not None and hasattr(model, 'default_cfg'): diff --git a/timm/data/dataset.py b/timm/data/dataset.py index a7c5ebed..e719f3f6 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -73,12 +73,13 @@ class IterableImageDataset(data.IterableDataset): batch_size=None, class_map='', load_bytes=False, + repeats=0, transform=None, ): assert parser is not None if isinstance(parser, str): self.parser = create_parser( - parser, root=root, split=split, is_training=is_training, batch_size=batch_size) + parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats) else: self.parser = parser self.transform = transform diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index b2c9688f..ccc99d5c 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -23,6 +23,7 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs) else: # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future + kwargs.pop('repeats', 0) # FIXME currently only Iterable dataset support the repeat multiplier if search_split and 
os.path.isdir(root): root = _search_split(root, split) ds = ImageDataset(root, parser=name, **kwargs) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 15361cb5..0c2e10c0 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -52,7 +52,7 @@ class ParserTfds(Parser): components. """ - def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None): + def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0): super().__init__() self.root = root self.split = split @@ -62,6 +62,7 @@ class ParserTfds(Parser): assert batch_size is not None,\ "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" self.batch_size = batch_size + self.repeats = repeats self.builder = tfds.builder(name, data_dir=root) # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call @@ -126,7 +127,7 @@ class ParserTfds(Parser): # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) ds.options().experimental_threading.max_intra_op_parallelism = 1 - if self.is_training: + if self.is_training or self.repeats > 1: # to prevent excessive drop_last batch behaviour w/ IterableDatasets # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually @@ -143,7 +144,7 @@ class ParserTfds(Parser): # This adds extra samples and will slightly alter validation results. # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size # batches are produced (underlying tfds iter wraps around) - target_sample_count = math.ceil(self.num_samples / self._num_pipelines) + target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines) if self.is_training: # round up to nearest batch_size per worker-replica target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size @@ -176,7 +177,7 @@ class ParserTfds(Parser): def __len__(self): # this is just an estimate and does not factor in extra samples added to pad batches based on # complete worker & replica info (not available until init in dataloader). 
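# Illustrative usage sketch (dataset name and root are placeholders, not from the
# patch): only the iterable/TFDS path consumes the new `repeats` argument; the
# ImageFolder path pops and ignores it, as the FIXME in dataset_factory above notes.
from timm.data import create_dataset

ds = create_dataset(
    'tfds/imagenette', root='/data/tfds', split='train', is_training=True,
    batch_size=64, repeats=2)   # roughly two passes over the data per "epoch"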
- return math.ceil(self.num_samples / self.dist_num_replicas) + return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) def _filename(self, index, basename=False, absolute=False): assert False, "Not supported" # no random access to samples diff --git a/train.py b/train.py index 9db5175b..2fdf68d8 100755 --- a/train.py +++ b/train.py @@ -141,6 +141,8 @@ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 2)') +parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', + help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', @@ -450,7 +452,9 @@ def main(): # create the train and eval datasets dataset_train = create_dataset( - args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size) + args.dataset, + root=args.data_dir, split=args.train_split, is_training=True, + batch_size=args.batch_size, repeats=args.epoch_repeats) dataset_eval = create_dataset( args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) diff --git a/validate.py b/validate.py index a311112d..6df71aab 100755 --- a/validate.py +++ b/validate.py @@ -152,7 +152,7 @@ def validate(args): param_count = sum([m.numel() for m in model.parameters()]) _logger.info('Model %s created, param count: %d' % (args.model, param_count)) - data_config = resolve_data_config(vars(args), model=model, use_test_size=True) + data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) test_time_pool = False if not args.no_test_pool: model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) From 0706d05d525e67753137a38057b513d793701fe0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 28 Feb 2021 16:00:33 -0800 Subject: [PATCH 06/24] Benchmark models listed in txt file. Add more hybrid vit variants for testing --- benchmark.py | 13 +- timm/models/vision_transformer.py | 261 ++++++++++++++++++++++++++++-- 2 files changed, 261 insertions(+), 13 deletions(-) diff --git a/benchmark.py b/benchmark.py index c745e278..faac1fc5 100755 --- a/benchmark.py +++ b/benchmark.py @@ -45,6 +45,8 @@ _logger = logging.getLogger('validate') parser = argparse.ArgumentParser(description='PyTorch Benchmark') # benchmark specific args +parser.add_argument('--model-list', metavar='NAME', default='', + help='txt file based list of model names to benchmark') parser.add_argument('--bench', default='both', type=str, help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'") parser.add_argument('--detail', action='store_true', default=False, @@ -357,7 +359,7 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): except RuntimeError as e: torch.cuda.empty_cache() batch_size = decay_batch_exp(batch_size) - print(f'Reducing batch size to {batch_size}') + print(f'Error: {str(e)} while running benchmark. 
Reducing batch size to {batch_size} for retry.') return results @@ -413,7 +415,12 @@ def main(): model_cfgs = [] model_names = [] - if args.model == 'all': + if args.model_list: + args.model = '' + with open(args.model_list) as f: + model_names = [line.rstrip() for line in f] + model_cfgs = [(n, None) for n in model_names] + elif args.model == 'all': # validate all models in a list of names with pretrained checkpoints args.pretrained = True model_names = list_models(pretrained=True, exclude_filters=['*in21k']) @@ -429,6 +436,8 @@ def main(): results = [] try: for m, _ in model_cfgs: + if not m: + continue args.model = m r = benchmark(args) results.append(r) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index f6a09ac2..f834d8e1 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -103,48 +103,90 @@ default_cfgs = { num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'), # hybrid in-1k models (weights ported from official Google JAX impl where they exist) + 'vit_tiny_r_s16_p8_224': _cfg( + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_tiny_r_s16_p8_224_in21k': _cfg( + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_tiny_r_s16_p8_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r_s16_p8_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r_s16_p8_224_in21k': _cfg( + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r_s16_p8_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_p2_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_p2_224_in21k': _cfg( + inum_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 'vit_small_r20_s16_p2_384': _cfg( input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + + 'vit_small_r20_s16_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_224_in21k': _cfg( + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 'vit_small_r20_s16_384': _cfg( input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r26_s32_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r26_s32_224_in21k': _cfg( + num_classes=21843, mean=(0.5, 0.5, 0.5), 
std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 'vit_small_r26_s32_384': _cfg( input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r20_s16_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r20_s16_224_in21k': _cfg( + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 'vit_base_r20_s16_384': _cfg( input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r26_s32_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r26_s32_224_in21k': _cfg( + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 'vit_base_r26_s32_384': _cfg( input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r50_s16_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 'vit_base_r50_s16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + 'vit_large_r50_s32_224': _cfg( - input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_large_r50_s32_224_in21k': _cfg( + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 'vit_large_r50_s32_384': _cfg( input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, @@ -159,8 +201,19 @@ default_cfgs = { # deit models (FB weights) 'vit_deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), + 'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843), + 'vit_deit_tiny_patch16_224_in21k_norep': _cfg(num_classes=21843), + 'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)), + 'vit_deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), + 'vit_deit_small_patch16_224_in21k': _cfg(num_classes=21843), + 'vit_deit_small_patch16_384': _cfg(input_size=(3, 384, 384)), + + 'vit_deit_small_patch32_224': _cfg(), + 'vit_deit_small_patch32_224_in21k': _cfg(num_classes=21843), + 'vit_deit_small_patch32_384': _cfg(input_size=(3, 384, 384)), + 'vit_deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), 'vit_deit_base_patch16_384': _cfg( @@ -728,7 +781,29 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): backbone = _resnetv2(layers=(), **kwargs) model_kwargs = dict( patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) - model = 
_create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=3, representation_size=192, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_tiny_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs) return model @@ -740,6 +815,29 @@ def vit_small_r_s16_p8_224(pretrained=False, **kwargs): model_kwargs = dict( patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs) + + return model + + +@register_model +def vit_small_r_s16_p8_224_in21k(pretrained=False, **kwargs): + """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs) return model @@ -754,6 +852,17 @@ def vit_small_r20_s16_p2_224(pretrained=False, **kwargs): return model +@register_model +def vit_small_r20_s16_p2_224_in21k(pretrained=False, **kwargs): + """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2((2, 4), **kwargs) + model_kwargs = dict( + patch_size=2, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_p2_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_r20_s16_p2_384(pretrained=False, **kwargs): """ R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384. @@ -775,6 +884,16 @@ def vit_small_r20_s16_224(pretrained=False, **kwargs): return model +@register_model +def vit_small_r20_s16_224_in21k(pretrained=False, **kwargs): + """ R20+ViT-S/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_r20_s16_384(pretrained=False, **kwargs): """ R20+ViT-S/S16 hybrid @ 384x384. 
@@ -795,6 +914,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs): return model +@register_model +def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict( + embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_r26_s32_384(pretrained=False, **kwargs): """ R26+ViT-S/S32 hybrid @ 384x384. @@ -810,12 +940,22 @@ def vit_base_r20_s16_224(pretrained=False, **kwargs): """ R20+ViT-B/S16 hybrid. """ backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict( - embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs) return model +@register_model +def vit_base_r20_s16_224_in21k(pretrained=False, **kwargs): + """ R20+ViT-B/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_r20_s16_384(pretrained=False, **kwargs): """ R20+ViT-B/S16 hybrid. @@ -836,6 +976,27 @@ def vit_base_r26_s32_224(pretrained=False, **kwargs): return model +@register_model +def vit_base_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-B/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r26_s32_384(pretrained=False, **kwargs): + """ R26+ViT-B/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_r50_s16_224(pretrained=False, **kwargs): """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). @@ -867,6 +1028,17 @@ def vit_large_r50_s32_224(pretrained=False, **kwargs): return model +@register_model +def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_large_r50_s32_384(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. 
@@ -927,6 +1099,31 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): return model +@register_model +def vit_deit_tiny_patch16_224_in21k_norep(pretrained=False, **kwargs): + """ DeiT-tiny model""" + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k_norep', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs): + """ DeiT-tiny model""" + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, representation_size=192, **kwargs) + model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_tiny_patch16_384(pretrained=False, **kwargs): + """ DeiT-tiny model""" + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_deit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + + @register_model def vit_deit_small_patch16_224(pretrained=False, **kwargs): """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). @@ -937,6 +1134,48 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs): return model +@register_model +def vit_deit_small_patch16_224_in21k(pretrained=False, **kwargs): + """ DeiT-small """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) + model = _create_vision_transformer('vit_deit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_small_patch16_384(pretrained=False, **kwargs): + """ DeiT-small """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_deit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_small_patch32_224(pretrained=False, **kwargs): + """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs): + """ DeiT-small """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) + model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_deit_small_patch32_384(pretrained=False, **kwargs): + """ DeiT-small """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_deit_base_patch16_224(pretrained=False, **kwargs): """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). From 17cdee7354d70ab8874583cd3f2868840aeb8f05 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 1 Mar 2021 16:53:32 -0800 Subject: [PATCH 07/24] Fix C&P patch_size error, and order of op patch_size arg resolution bug. Remove a test vit model. 
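A note on the "order of op" bug, since it is easy to miss: Python's conditional expression binds more loosely than 'or', so the un-parenthesized default clobbered an explicitly passed patch_size whenever no hybrid backbone was used. A minimal standalone sketch of the before/after behaviour (illustration only, not part of the patch itself):

def resolve_patch_size_buggy(patch_size, hybrid_backbone):
    # Parses as: (patch_size or 1) if hybrid_backbone is not None else 16,
    # so non-hybrid models always resolve to 16, even when patch_size was given.
    return patch_size or 1 if hybrid_backbone is not None else 16

def resolve_patch_size_fixed(patch_size, hybrid_backbone):
    # Falls back to a default (1 for hybrid, 16 otherwise) only when patch_size is unset.
    return patch_size or (1 if hybrid_backbone is not None else 16)

assert resolve_patch_size_buggy(32, hybrid_backbone=None) == 16      # explicit value lost
assert resolve_patch_size_fixed(32, hybrid_backbone=None) == 32      # explicit value kept
assert resolve_patch_size_fixed(None, hybrid_backbone=None) == 16    # non-hybrid default
assert resolve_patch_size_fixed(None, hybrid_backbone=object()) == 1 # hybrid default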
--- timm/models/vision_transformer.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index f834d8e1..aed295ec 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -202,7 +202,6 @@ default_cfgs = { 'vit_deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), 'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843), - 'vit_deit_tiny_patch16_224_in21k_norep': _cfg(num_classes=21843), 'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)), 'vit_deit_small_patch16_224': _cfg( @@ -399,7 +398,7 @@ class VisionTransformer(nn.Module): self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU - patch_size = patch_size or 1 if hybrid_backbone is not None else 16 + patch_size = patch_size or (1 if hybrid_backbone is not None else 16) if hybrid_backbone is not None: self.patch_embed = HybridEmbed( @@ -1099,14 +1098,6 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): return model -@register_model -def vit_deit_tiny_patch16_224_in21k_norep(pretrained=False, **kwargs): - """ DeiT-tiny model""" - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k_norep', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs): """ DeiT-tiny model""" @@ -1155,7 +1146,7 @@ def vit_deit_small_patch32_224(pretrained=False, **kwargs): """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. 
""" - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs) return model @@ -1163,7 +1154,7 @@ def vit_deit_small_patch32_224(pretrained=False, **kwargs): @register_model def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs): """ DeiT-small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -1171,7 +1162,7 @@ def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs): @register_model def vit_deit_small_patch32_384(pretrained=False, **kwargs): """ DeiT-small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs) return model From 4445eaa470ff321413b114c258de1be64d7156a6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 5 Mar 2021 16:48:31 -0800 Subject: [PATCH 08/24] Add img_size to benchmark output --- benchmark.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmark.py b/benchmark.py index faac1fc5..e692eacc 100755 --- a/benchmark.py +++ b/benchmark.py @@ -208,6 +208,7 @@ class InferenceBenchmarkRunner(BenchmarkRunner): samples_per_sec=round(num_samples / t_run_elapsed, 2), step_time=round(1000 * total_step / num_samples, 3), batch_size=self.batch_size, + img_size=self.input_size[-1], param_count=round(self.param_count / 1e6, 2), ) @@ -310,6 +311,7 @@ class TrainBenchmarkRunner(BenchmarkRunner): bwd_time=round(1000 * total_bwd / num_samples, 3), opt_time=round(1000 * total_opt / num_samples, 3), batch_size=self.batch_size, + img_size=self.input_size[-1], param_count=round(self.param_count / 1e6, 2), ) else: From 4de57ccf0123650bf759960d9ac64dca6263da7c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Mar 2021 15:35:22 -0700 Subject: [PATCH 09/24] Add weight init scheme that's closer to JAX impl --- timm/models/vision_transformer.py | 74 +++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 13 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index aed295ec..5fb5c7c7 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -373,7 +373,7 @@ class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, - act_layer=None): + act_layer=None, weight_init=''): """ Args: img_size (int, tuple): input image size @@ -434,17 +434,13 @@ class VisionTransformer(nn.Module): self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) - trunc_normal_(self.cls_token, std=.02) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, 
nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + if weight_init != 'jax': # leave as zeros to match JAX impl + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if weight_init == 'jax': + _init_weights_jax(m, n) + else: + _init_weights_original(m, n) @torch.jit.ignore def no_weight_decay(self): @@ -479,6 +475,58 @@ class VisionTransformer(nn.Module): return x +def _init_weights_original(m: nn.Module, n: str = ''): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.zeros_(m.bias) + nn.init.ones_(m.weight) + + +def _init_weights_jax(m: nn.Module, n: str): + """ Weight init scheme closer to the official JAX impl than my original init""" + + def _fan_in(tensor): + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + + num_input_fmaps = tensor.size(1) + receptive_field_size = 1 + if tensor.dim() > 2: + receptive_field_size = tensor[0][0].numel() + fan_in = num_input_fmaps * receptive_field_size + return fan_in + + def _lecun_normal(w): + stddev = (1.0 / _fan_in(w)) ** 0.5 / .87962566103423978 + trunc_normal_(w, 0, stddev) + + if isinstance(m, nn.Linear): + if 'head' in n: + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + elif 'pre_logits' in n: + _lecun_normal(m.weight) + nn.init.zeros_(m.bias) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if 'mlp' in n: + nn.init.normal_(m.bias, 0, 1e-6) + else: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + _lecun_normal(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0.) + nn.init.constant_(m.weight, 1.) + + class DistilledVisionTransformer(VisionTransformer): """ Vision Transformer with distillation token. @@ -496,7 +544,7 @@ class DistilledVisionTransformer(VisionTransformer): trunc_normal_(self.dist_token, std=.02) trunc_normal_(self.pos_embed, std=.02) - self.head_dist.apply(self._init_weights) + self.head_dist.apply(_init_weights_original) def forward_features(self, x): B = x.shape[0] From cbcb76d72c74cab6d0ab12915e1cf851605c6f59 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Mar 2021 23:15:48 -0700 Subject: [PATCH 10/24] Should have included Conv2d layers in original weight init. Lets see what the impact is... 
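Worth spelling out why this matters: the init functions are applied per-module over named_modules(), so any layer type the isinstance check misses silently keeps PyTorch's default (Kaiming-uniform) initialization - before this change that included Conv2d layers such as the patch embedding projection. A standalone sketch with a toy module (illustration only; bias and LayerNorm handling omitted):

import torch.nn as nn
from timm.models.layers import trunc_normal_

def init_weights_linear_only(m, n=''):
    # pre-patch behaviour: Conv2d layers (e.g. the patch embedding projection) are skipped
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)

def init_weights_conv_and_linear(m, n=''):
    # post-patch behaviour: both conv and linear weights get the truncated normal init
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        trunc_normal_(m.weight, std=.02)

toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=16, stride=16), nn.Flatten(), nn.Linear(8 * 14 * 14, 10))
for n, m in toy.named_modules():
    init_weights_conv_and_linear(m, n)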
--- timm/models/vision_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 5fb5c7c7..42943fab 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -476,7 +476,7 @@ class VisionTransformer(nn.Module): def _init_weights_original(m: nn.Module, n: str = ''): - if isinstance(m, nn.Linear): + if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) From f42f1df26c5b44321a9cc65aca3f728a89d7479d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Mar 2021 23:16:14 -0700 Subject: [PATCH 11/24] Improve evenness of per-worker split for validation set with TFDS --- timm/data/parsers/parser_tfds.py | 41 +++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 0c2e10c0..0b12a4db 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -29,6 +29,11 @@ SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue PREFETCH_SIZE = 4096 # samples to prefetch +def even_split_indices(split, n, num_samples): + partitions = [round(i * num_samples / n) for i in range(n + 1)] + return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)] + + class ParserTfds(Parser): """ Wrap Tensorflow Datasets for use in PyTorch @@ -63,6 +68,7 @@ class ParserTfds(Parser): "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" self.batch_size = batch_size self.repeats = repeats + self.subsplit = None self.builder = tfds.builder(name, data_dir=root) # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call @@ -96,6 +102,7 @@ class ParserTfds(Parser): if worker_info is not None: self.worker_info = worker_info num_workers = worker_info.num_workers + global_num_workers = self.dist_num_replicas * num_workers worker_id = worker_info.id # FIXME I need to spend more time figuring out the best way to distribute/split data across @@ -115,15 +122,27 @@ class ParserTfds(Parser): # split = split + '[{}:]'.format(start) # else: # split = split + '[{}:{}]'.format(start, start + split_size) - - input_context = tf.distribute.InputContext( - num_input_pipelines=self.dist_num_replicas * num_workers, - input_pipeline_id=self.dist_rank * num_workers + worker_id, - num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact? - ) - - read_config = tfds.ReadConfig(input_context=input_context) - ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config) + if not self.is_training and '[' not in self.split: + # If not training, and split doesn't define a subsplit, manually split the dataset + # for more even samples / worker + self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[ + self.dist_rank * num_workers + worker_id] + + if self.subsplit is None: + input_context = tf.distribute.InputContext( + num_input_pipelines=self.dist_num_replicas * num_workers, + input_pipeline_id=self.dist_rank * num_workers + worker_id, + num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? 
+ ) + else: + input_context = None + + read_config = tfds.ReadConfig( + shuffle_seed=42, + shuffle_reshuffle_each_iteration=True, + input_context=input_context) + ds = self.builder.as_dataset( + split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) ds.options().experimental_threading.max_intra_op_parallelism = 1 @@ -161,8 +180,8 @@ class ParserTfds(Parser): if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count: # Validation batch padding only done for distributed training where results are reduced across nodes. # For single process case, it won't matter if workers return different batch sizes. - # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1? - assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal + # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this + # approach is not optimal yield img, sample['label'] # yield prev sample again sample_count += 1 From cf5fec504754ecd56b2d4307b521a1d7d2eeaa8a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 20 Mar 2021 09:44:24 -0700 Subject: [PATCH 12/24] Cleanup experimental vit weight init a bit --- timm/models/layers/__init__.py | 2 +- timm/models/layers/weight_init.py | 29 +++++++++++++ timm/models/vision_transformer.py | 69 +++++++++++++++++-------------- 3 files changed, 68 insertions(+), 32 deletions(-) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index f8d8d8c0..89fb859c 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -31,4 +31,4 @@ from .split_attn import SplitAttnConv2d from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .weight_init import trunc_normal_ +from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py index d731029f..305a2fd0 100644 --- a/timm/models/layers/weight_init.py +++ b/timm/models/layers/weight_init.py @@ -2,6 +2,8 @@ import torch import math import warnings +from torch.nn.init import _calculate_fan_in_and_fan_out + def _no_grad_trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW @@ -58,3 +60,30 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): >>> nn.init.trunc_normal_(w) """ return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise 
ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 42943fab..45c1eddb 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -28,7 +28,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained -from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_ +from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_, lecun_normal_ from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem from .registry import register_model @@ -373,7 +373,7 @@ class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, - act_layer=None, weight_init=''): + act_layer=None, weight_init='new_nlhb'): """ Args: img_size (int, tuple): input image size @@ -433,14 +433,20 @@ class VisionTransformer(nn.Module): # Classifier head self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self._init_weights(weight_init) + + def _init_weights(self, weight_init: str): trunc_normal_(self.pos_embed, std=.02) - if weight_init != 'jax': # leave as zeros to match JAX impl + if weight_init.startswith('jax'): + init_fn = _init_weights_jax + # leave cls token as zeros to match jax impl + else: trunc_normal_(self.cls_token, std=.02) + init_fn = _init_weights_new if weight_init.startswith('new') else _init_weights_old + hb = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. 
+ init_fn = partial(init_fn, head_bias=hb) for n, m in self.named_modules(): - if weight_init == 'jax': - _init_weights_jax(m, n) - else: - _init_weights_original(m, n) + init_fn(m, n) @torch.jit.ignore def no_weight_decay(self): @@ -475,41 +481,42 @@ class VisionTransformer(nn.Module): return x -def _init_weights_original(m: nn.Module, n: str = ''): - if isinstance(m, (nn.Conv2d, nn.Linear)): +def _init_weights_old(m: nn.Module, n: str = '', head_bias: float = 0.): + if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) + if m.bias is not None: + if 'head' in n: + nn.init.constant_(m.bias, head_bias) + else: + nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.zeros_(m.bias) nn.init.ones_(m.weight) -def _init_weights_jax(m: nn.Module, n: str): - """ Weight init scheme closer to the official JAX impl than my original init""" - - def _fan_in(tensor): - dimensions = tensor.dim() - if dimensions < 2: - raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") +def _init_weights_new(m: nn.Module, n: str = '', head_bias: float = 0.): + if isinstance(m, (nn.Conv2d, nn.Linear)): + #trunc_normal_(m.weight, std=.02) + lecun_normal_(m.weight) + if m.bias is not None: + if 'head' in n: + nn.init.constant_(m.bias, head_bias) + else: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.zeros_(m.bias) + nn.init.ones_(m.weight) - num_input_fmaps = tensor.size(1) - receptive_field_size = 1 - if tensor.dim() > 2: - receptive_field_size = tensor[0][0].numel() - fan_in = num_input_fmaps * receptive_field_size - return fan_in - def _lecun_normal(w): - stddev = (1.0 / _fan_in(w)) ** 0.5 / .87962566103423978 - trunc_normal_(w, 0, stddev) +def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.): + """ Attempt at weight init scheme closer to the official JAX impl than my original init""" if isinstance(m, nn.Linear): if 'head' in n: nn.init.zeros_(m.weight) - nn.init.zeros_(m.bias) + nn.init.constant_(m.bias, head_bias) elif 'pre_logits' in n: - _lecun_normal(m.weight) + lecun_normal_(m.weight) nn.init.zeros_(m.bias) else: nn.init.xavier_uniform_(m.weight) @@ -519,7 +526,7 @@ def _init_weights_jax(m: nn.Module, n: str): else: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): - _lecun_normal(m.weight) + lecun_normal_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): @@ -544,7 +551,7 @@ class DistilledVisionTransformer(VisionTransformer): trunc_normal_(self.dist_token, std=.02) trunc_normal_(self.pos_embed, std=.02) - self.head_dist.apply(_init_weights_original) + self.head_dist.apply(_init_weights_new) def forward_features(self, x): B = x.shape[0] From e2e3290fbfdd6dc9b6f4f8ba206317faac20d956 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 20 Mar 2021 12:02:17 -0700 Subject: [PATCH 13/24] Add '--experiment' to train args for fixed exp name if desired, 'train' not added to output folder if specified. 
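A quick note on the resulting layout: with --experiment the run lands in a fixed sub-folder, otherwise a timestamped name is generated, and the extra 'train' path component is only appended when --output is left unset. A standalone sketch of the resolution logic (folder names below are illustrative; the real code goes through get_outdir, which also creates the directory):

import os
from datetime import datetime

def resolve_output_dir(output='', experiment='', model='resnet50', input_size=224):
    exp_name = experiment or '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"), model, str(input_size)])
    return os.path.join(output if output else './output/train', exp_name)

# resolve_output_dir(experiment='my_vit_run')                      -> './output/train/my_vit_run'
# resolve_output_dir(model='vit_base_patch16_224')                 -> './output/train/20210320-171502-vit_base_patch16_224-224'
# resolve_output_dir(output='/data/runs', experiment='my_vit_run') -> '/data/runs/my_vit_run'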
--- train.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 2fdf68d8..d0886c08 100755 --- a/train.py +++ b/train.py @@ -259,6 +259,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') +parser.add_argument('--experiment', default='', type=str, metavar='NAME', + help='name of train experiment, name of sub-folder for output') parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', help='Best metric (default: "top1"') parser.add_argument('--tta', type=int, default=0, metavar='N', @@ -544,13 +546,15 @@ def main(): saver = None output_dir = '' if args.local_rank == 0: - output_base = args.output if args.output else './output' - exp_name = '-'.join([ - datetime.now().strftime("%Y%m%d-%H%M%S"), - args.model, - str(data_config['input_size'][-1]) - ]) - output_dir = get_outdir(output_base, 'train', exp_name) + if args.experiment: + exp_name = args.experiment + else: + exp_name = '-'.join([ + datetime.now().strftime("%Y%m%d-%H%M%S"), + args.model, + str(data_config['input_size'][-1]) + ]) + output_dir = get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, From 0dfc5a66bb5253bb0b473cce3ea741768f5ae5eb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 31 Mar 2021 18:20:14 -0700 Subject: [PATCH 14/24] Add PiT model from https://github.com/naver-ai/pit --- tests/test_models.py | 2 +- timm/models/__init__.py | 1 + timm/models/pit.py | 385 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 387 insertions(+), 1 deletion(-) create mode 100644 timm/models/pit.py diff --git a/tests/test_models.py b/tests/test_models.py index 1f70d115..639a0534 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,7 +14,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/__init__.py b/timm/models/__init__.py index ab3c4b2f..d810909c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -14,6 +14,7 @@ from .inception_v4 import * from .mobilenetv3 import * from .nasnet import * from .nfnet import * +from .pit import * from .pnasnet import * from .regnet import * from .res2net import * diff --git a/timm/models/pit.py b/timm/models/pit.py new file mode 100644 index 00000000..2137bea8 --- /dev/null +++ b/timm/models/pit.py @@ -0,0 +1,385 @@ +""" Pooling-based Vision Transformer (PiT) in PyTorch + +A PyTorch implement of Pooling-based Vision Transformers as described in +'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302 + +This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below. + +Modifications for timm by / Copyright 2020 Ross Wightman +""" +# PiT +# Copyright 2021-present NAVER Corp. 
+# Apache License v2.0 + +import math +import re +from copy import deepcopy +from functools import partial +from typing import Tuple + +import torch +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import trunc_normal_, to_2tuple +from .registry import register_model +from .vision_transformer import Block + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # deit models (FB weights) + 'pit_ti_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_730.pth'), + 'pit_xs_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'), + 'pit_s_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'), + 'pit_b_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'), + 'pit_ti_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth'), + 'pit_xs_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth'), + 'pit_s_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth'), + 'pit_b_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth'), + +} + + +class SequentialTuple(nn.Sequential): + """ This module exists to work around torchscript typing issues list -> list""" + def __init__(self, *args): + super(SequentialTuple, self).__init__(*args) + + def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + for module in self: + x = module(x) + return x + + +class Transformer(nn.Module): + def __init__( + self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None): + super(Transformer, self).__init__() + self.layers = nn.ModuleList([]) + embed_dim = base_dim * heads + + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_prob[i], + norm_layer=partial(nn.LayerNorm, eps=1e-6) + ) + for i in range(depth)]) + + self.pool = pool + + def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + x, cls_tokens = x + B, C, H, W = x.shape + token_length = cls_tokens.shape[1] + + x = x.flatten(2).transpose(1, 2) + x = torch.cat((cls_tokens, x), dim=1) + + x = self.blocks(x) + + cls_tokens = x[:, :token_length] + x = x[:, token_length:] + x = x.transpose(1, 2).reshape(B, C, H, W) + + if self.pool is not None: + x, cls_tokens = self.pool(x, cls_tokens) + return x, cls_tokens + + +class ConvHeadPooling(nn.Module): + def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'): + super(ConvHeadPooling, self).__init__() + + self.conv = nn.Conv2d( + in_feature, out_feature, kernel_size=stride + 1, 
padding=stride // 2, stride=stride, + padding_mode=padding_mode, groups=in_feature) + self.fc = nn.Linear(in_feature, out_feature) + + def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: + + x = self.conv(x) + cls_token = self.fc(cls_token) + + return x, cls_token + + +class ConvEmbedding(nn.Module): + def __init__(self, in_channels, out_channels, patch_size, stride, padding): + super(ConvEmbedding, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True) + + def forward(self, x): + x = self.conv(x) + return x + + +class PoolingVisionTransformer(nn.Module): + """ Pooling-based Vision Transformer + + A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers' + - https://arxiv.org/abs/2103.16302 + """ + def __init__(self, img_size, patch_size, stride, base_dims, depth, heads, + mlp_ratio, num_classes=1000, in_chans=3, distilled=False, + attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0): + super(PoolingVisionTransformer, self).__init__() + + padding = 0 + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1) + width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1) + + self.base_dims = base_dims + self.heads = heads + self.num_classes = num_classes + self.num_tokens = 2 if distilled else 1 + + self.patch_size = patch_size + self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width)) + self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding) + + self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0])) + self.pos_drop = nn.Dropout(p=drop_rate) + + transformers = [] + # stochastic depth decay rule + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)] + for stage in range(len(depth)): + pool = None + if stage < len(heads) - 1: + pool = ConvHeadPooling( + base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2) + transformers += [Transformer( + base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool, + drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage]) + ] + self.transformers = SequentialTuple(*transformers) + self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) + self.embed_dim = base_dims[-1] * heads[-1] + + # Classifier head + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ + if num_classes > 0 and distilled else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ + if num_classes > 0 and self.num_tokens == 2 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x + self.pos_embed) + cls_tokens = 
self.cls_token.expand(x.shape[0], -1, -1) + x, cls_tokens = self.transformers((x, cls_tokens)) + cls_tokens = self.norm(cls_tokens) + return cls_tokens + + def forward(self, x): + x = self.forward_features(x) + x_cls = self.head(x[:, 0]) + if self.num_tokens > 1: + x_dist = self.head_dist(x[:, 1]) + if self.training and not torch.jit.is_scripting(): + return x_cls, x_dist + else: + return (x_cls + x_dist) / 2 + else: + return x_cls + + +def checkpoint_filter_fn(state_dict, model): + """ preprocess checkpoints """ + out_dict = {} + p_blocks = re.compile(r'pools\.(\d)\.') + for k, v in state_dict.items(): + # FIXME need to update resize for PiT impl + # if k == 'pos_embed' and v.shape != model.pos_embed.shape: + # # To resize pos embedding when using model at different size from pretrained weights + # v = resize_pos_embed(v, model.pos_embed) + k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k) + out_dict[k] = v + return out_dict + + +def _create_pit(variant, pretrained=False, **kwargs): + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + img_size = kwargs.pop('img_size', default_img_size) + num_classes = kwargs.pop('num_classes', default_num_classes) + + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + PoolingVisionTransformer, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +@register_model +def pit_b_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=14, + stride=7, + base_dims=[64, 64, 64], + depth=[3, 6, 4], + heads=[4, 8, 16], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_b_224', pretrained, **model_kwargs) + + +@register_model +def pit_s_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[3, 6, 12], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_s_224', pretrained, **model_kwargs) + + +@register_model +def pit_xs_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_xs_224', pretrained, **model_kwargs) + + +@register_model +def pit_ti_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[32, 32, 32], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_ti_224', pretrained, **model_kwargs) + + +@register_model +def pit_b_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=14, + stride=7, + base_dims=[64, 64, 64], + depth=[3, 6, 4], + heads=[4, 8, 16], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_s_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[3, 6, 12], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_xs_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 
4], + heads=[2, 4, 8], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_ti_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[32, 32, 32], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs) \ No newline at end of file From a760a4c3f4f010856d6914d5f31a65e0e18adc66 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 31 Mar 2021 18:21:02 -0700 Subject: [PATCH 15/24] Some ViT cleanup, merge distilled model with main, fixup torchscript support for distilled models --- timm/models/vision_transformer.py | 134 +++++++++++++----------------- 1 file changed, 56 insertions(+), 78 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 82d4ee49..c7c9027d 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -268,12 +268,16 @@ class HybridEmbed(nn.Module): class VisionTransformer(nn.Module): """ Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + weight_init=''): """ Args: img_size (int, tuple): input image size @@ -287,11 +291,13 @@ class VisionTransformer(nn.Module): qkv_bias (bool): enable bias for qkv if True qk_scale (float): override default qk scale of head_dim ** -0.5 if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme """ super().__init__() self.num_classes = num_classes @@ -307,11 +313,13 @@ class VisionTransformer(nn.Module): num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None + num_tokens = 2 if distilled else 1 + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList([ + self.blocks = nn.Sequential(*[ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, 
qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) @@ -319,7 +327,7 @@ class VisionTransformer(nn.Module): self.norm = norm_layer(embed_dim) # Representation layer - if representation_size: + if representation_size and not distilled: self.num_features = representation_size self.pre_logits = nn.Sequential(OrderedDict([ ('fc', nn.Linear(embed_dim, representation_size)), @@ -328,11 +336,15 @@ class VisionTransformer(nn.Module): else: self.pre_logits = nn.Identity() - # Classifier head + # Classifier head(s) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ + if num_classes > 0 and distilled else nn.Identity() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) self.apply(self._init_weights) def _init_weights(self, m): @@ -346,91 +358,58 @@ class VisionTransformer(nn.Module): @torch.jit.ignore def no_weight_decay(self): - return {'pos_embed', 'cls_token'} + return {'pos_embed', 'cls_token', 'dist_token'} def get_classifier(self): - return self.head + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ + if num_classes > 0 and self.dist_token is not None else nn.Identity() def forward_features(self, x): - B = x.shape[0] - x = self.patch_embed(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embed - x = self.pos_drop(x) - - for blk in self.blocks: - x = blk(x) - - x = self.norm(x)[:, 0] - x = self.pre_logits(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - -class DistilledVisionTransformer(VisionTransformer): - """ Vision Transformer with distillation token. 
- - Paper: `Training data-efficient image transformers & distillation through attention` - - https://arxiv.org/abs/2012.12877 - - This impl of distilled ViT is taken from https://github.com/facebookresearch/deit - """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - num_patches = self.patch_embed.num_patches - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() - - trunc_normal_(self.dist_token, std=.02) - trunc_normal_(self.pos_embed, std=.02) - self.head_dist.apply(self._init_weights) - - def forward_features(self, x): - B = x.shape[0] x = self.patch_embed(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks - dist_token = self.dist_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, dist_token, x), dim=1) - - x = x + self.pos_embed - x = self.pos_drop(x) - - for blk in self.blocks: - x = blk(x) - + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) x = self.norm(x) - return x[:, 0], x[:, 1] + if self.dist_token is not None: + return x[:, 0], x[:, 1] + else: + return self.pre_logits(x[:, 0]) def forward(self, x): - x, x_dist = self.forward_features(x) - x = self.head(x) - x_dist = self.head_dist(x_dist) - if self.training: - return x, x_dist + x = self.forward_features(x) + if isinstance(x, tuple): + x, x_dist = self.head(x[0]), self.head_dist(x[1]) + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 else: - # during inference, return the average of both classifier predictions - return (x + x_dist) / 2 + x = self.head(x) + return x -def resize_pos_embed(posemb, posemb_new): +def resize_pos_embed(posemb, posemb_new, token='class'): # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if True: - posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + if token: + assert token in ('class', 'distill') + token_idx = 2 if token == 'distill' else 1 + posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:] ntok_new -= 1 else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] @@ -457,12 +436,12 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed) + v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class') out_dict[k] = v return out_dict -def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs): +def _create_vision_transformer(variant, pretrained=False, **kwargs): default_cfg = deepcopy(default_cfgs[variant]) overlay_external_default_cfg(default_cfg, kwargs) default_num_classes = default_cfg['num_classes'] @@ -480,9 +459,8 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwa if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - model_cls = DistilledVisionTransformer if distilled else VisionTransformer model = build_model_with_cfg( - model_cls, variant, pretrained, + VisionTransformer, variant, pretrained, default_cfg=default_cfg, img_size=img_size, num_classes=num_classes, From 7953e5d11af1dbef49fd60d9aeaba8c1d740096c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 31 Mar 2021 23:11:28 -0700 Subject: [PATCH 16/24] Fix pos_embed scaling for ViT and num_classes != 1000 for pretrained distilled deit and pit models. 
Fix #426 and fix #433 --- timm/models/helpers.py | 24 ++++++++++++---------- timm/models/pit.py | 13 +++++++----- timm/models/vision_transformer.py | 33 ++++++++++++++++--------------- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 2f6e098b..e9ac7f00 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -198,20 +198,24 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte _logger.warning( f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') - classifier_name = default_cfg.get('classifier', None) + classifiers = default_cfg.get('classifier', None) label_offset = default_cfg.get('label_offset', 0) - if classifier_name is not None: + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) if num_classes != default_cfg['num_classes']: - # completely discard fully connected if model num_classes doesn't match pretrained weights - del state_dict[classifier_name + '.weight'] - del state_dict[classifier_name + '.bias'] + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] strict = False elif label_offset > 0: - # special case for pretrained weights with an extra background class in pretrained weights - classifier_weight = state_dict[classifier_name + '.weight'] - state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] - classifier_bias = state_dict[classifier_name + '.bias'] - state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] model.load_state_dict(state_dict, strict=strict) diff --git a/timm/models/pit.py b/timm/models/pit.py index 2137bea8..1cee4d04 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -49,14 +49,17 @@ default_cfgs = { 'pit_b_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'), 'pit_ti_distilled_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth', + classifier=('head', 'head_dist')), 'pit_xs_distilled_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth', + classifier=('head', 'head_dist')), 'pit_s_distilled_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth', + classifier=('head', 'head_dist')), 'pit_b_distilled_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth'), - + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth', + classifier=('head', 'head_dist')), } diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index c7c9027d..c05871b8 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -123,14 +123,17 @@ default_cfgs = { url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', input_size=(3, 384, 384), crop_pct=1.0), 'vit_deit_tiny_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'), + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', + classifier=('head', 'head_dist')), 'vit_deit_small_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'), + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + classifier=('head', 'head_dist')), 'vit_deit_base_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ), + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', + classifier=('head', 'head_dist')), 'vit_deit_base_distilled_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', - input_size=(3, 384, 384), crop_pct=1.0), + input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), } @@ -302,6 +305,7 @@ class VisionTransformer(nn.Module): super().__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) if hybrid_backbone is not None: @@ -313,9 +317,8 @@ class VisionTransformer(nn.Module): num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None - num_tokens = 2 if distilled else 1 - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -382,10 +385,10 @@ class VisionTransformer(nn.Module): x = self.pos_drop(x + self.pos_embed) x = self.blocks(x) x = self.norm(x) - if self.dist_token is not None: - return x[:, 0], x[:, 1] - else: + if self.dist_token is None: return self.pre_logits(x[:, 0]) + else: + return x[:, 0], x[:, 1] def forward(self, x): x = self.forward_features(x) @@ -401,15 +404,13 @@ class VisionTransformer(nn.Module): return x -def resize_pos_embed(posemb, posemb_new, token='class'): +def resize_pos_embed(posemb, posemb_new, num_tokens=1): # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if token: - assert token in ('class', 'distill') - token_idx = 2 if token == 'distill' else 1 - posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] ntok_new -= 1 else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] @@ -436,7 +437,7 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class') + v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1)) out_dict[k] = v return out_dict From ea9c9550b24dfaf30fdcca960b9cc24a65c359fe Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 14:17:38 -0700 Subject: [PATCH 17/24] Fully move ViT hybrids to their own file, including embedding module. Remove some extra DeiT models that were for benchmarking only. --- timm/models/vision_transformer.py | 134 +----------- timm/models/vision_transformer_hybrid.py | 256 +++++++++-------------- 2 files changed, 112 insertions(+), 278 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 5f244589..578a5f08 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -5,6 +5,9 @@ A PyTorch implement of Vision Transformers as described in The official jax code is released and available at https://github.com/google-research/vision_transformer +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out @@ -12,9 +15,6 @@ for some einops/einsum fun * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert -DeiT model defs and weights from https://github.com/facebookresearch/deit, -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 - Hacked together by / Copyright 2020 Ross Wightman """ import math @@ -99,18 +99,8 @@ default_cfgs = { # deit models (FB weights) 'vit_deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), - 'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843), - 'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)), - 'vit_deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), - 'vit_deit_small_patch16_224_in21k': _cfg(num_classes=21843), - 'vit_deit_small_patch16_384': _cfg(input_size=(3, 384, 384)), - - 'vit_deit_small_patch32_224': _cfg(), - 'vit_deit_small_patch32_224_in21k': _cfg(num_classes=21843), - 'vit_deit_small_patch32_384': _cfg(input_size=(3, 384, 384)), - 'vit_deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), 'vit_deit_base_patch16_384': _cfg( @@ -220,48 +210,6 @@ class PatchEmbed(nn.Module): return x -class HybridEmbed(nn.Module): - """ CNN Feature Map Embedding - Extract feature map from CNN, flatten, project to embedding dim. - """ - def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.backbone = backbone - if feature_size is None: - with torch.no_grad(): - # NOTE Most reliable way of determining output dims is to run forward pass - training = backbone.training - if training: - backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) - if isinstance(o, (list, tuple)): - o = o[-1] # last feature if backbone outputs list/tuple of features - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - else: - feature_size = to_2tuple(feature_size) - if hasattr(self.backbone, 'feature_info'): - feature_dim = self.backbone.feature_info.channels()[-1] - else: - feature_dim = self.backbone.num_features - assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] - self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x): - x = self.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - x = self.proj(x).flatten(2).transpose(1, 2) - return x - - class VisionTransformer(nn.Module): """ Vision Transformer @@ -274,7 +222,7 @@ class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None, weight_init=''): """ Args: @@ -293,7 +241,7 @@ 
class VisionTransformer(nn.Module): drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate - hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module + embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer weight_init: (str): weight init scheme """ @@ -303,14 +251,9 @@ class VisionTransformer(nn.Module): self.num_tokens = 2 if distilled else 1 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU - patch_size = patch_size or (1 if hybrid_backbone is not None else 16) - if hybrid_backbone is not None: - self.patch_embed = HybridEmbed( - hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - else: - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) @@ -489,8 +432,9 @@ def checkpoint_filter_fn(state_dict, model): return out_dict -def _create_vision_transformer(variant, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) overlay_external_default_cfg(default_cfg, kwargs) default_num_classes = default_cfg['num_classes'] default_img_size = default_cfg['input_size'][-2:] @@ -680,22 +624,6 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): return model -@register_model -def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs): - """ DeiT-tiny model""" - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, representation_size=192, **kwargs) - model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_tiny_patch16_384(pretrained=False, **kwargs): - """ DeiT-tiny model""" - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_deit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_deit_small_patch16_224(pretrained=False, **kwargs): """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). @@ -706,48 +634,6 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs): return model -@register_model -def vit_deit_small_patch16_224_in21k(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch16_384(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch16_384', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch32_224(pretrained=False, **kwargs): - """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. 
- """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch32_384(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_deit_base_patch16_224(pretrained=False, **kwargs): """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 293dd34d..816bbc8e 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -9,6 +9,12 @@ keep file sizes sane. Hacked together by / Copyright 2020 Ross Wightman """ +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import StdConv2dSame, StdConv2d, to_2tuple from .resnet import resnet26d, resnet50d @@ -41,39 +47,14 @@ default_cfgs = { # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones) 'vit_tiny_r_s16_p8_224': _cfg(), - 'vit_tiny_r_s16_p8_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - - 'vit_small_r_s16_p8_224': _cfg( - crop_pct=1.0), - 'vit_small_r_s16_p8_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - + 'vit_small_r_s16_p8_224': _cfg(), 'vit_small_r20_s16_p2_224': _cfg(), - 'vit_small_r20_s16_p2_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_r20_s16_224': _cfg(), - 'vit_small_r20_s16_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_r26_s32_224': _cfg(), - 'vit_small_r26_s32_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r20_s16_224': _cfg(), - 'vit_base_r20_s16_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r26_s32_224': _cfg(), - 'vit_base_r26_s32_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r50_s16_224': _cfg(), - 'vit_large_r50_s32_224': _cfg(), - 'vit_large_r50_s32_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), # hybrid models (using timm resnet backbones) 'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), @@ -83,6 +64,56 @@ default_cfgs = { } +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
+ """ + def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # NOTE Most reliable way of determining output dims is to run forward pass + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): + default_cfg = deepcopy(default_cfgs[variant]) + embed_layer = partial(HybridEmbed, backbone=backbone) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return _create_vision_transformer( + variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) + + def _resnetv2(layers=(3, 4, 9), **kwargs): """ ResNet-V2 backbone helper""" padding_same = kwargs.get('padding_same', True) @@ -108,9 +139,9 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict( - embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs) - model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -120,8 +151,9 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs): ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ backbone = _resnetv2((3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -130,20 +162,9 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. 
""" backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): - """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -152,21 +173,10 @@ def vit_small_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. """ backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs) - - return model - + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) -@register_model -def vit_small_r_s16_p8_384(pretrained=False, **kwargs): - """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs) return model @@ -175,20 +185,9 @@ def vit_small_r20_s16_p2_224(pretrained=False, **kwargs): """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224. """ backbone = _resnetv2((2, 4), **kwargs) - model_kwargs = dict( - patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r20_s16_p2_384(pretrained=False, **kwargs): - """ R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384. - """ - backbone = _resnetv2((2, 4), **kwargs) - model_kwargs = dict( - embed_dim=384, patch_size=2, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -197,18 +196,9 @@ def vit_small_r20_s16_224(pretrained=False, **kwargs): """ R20+ViT-S/S16 hybrid. """ backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r20_s16_384(pretrained=False, **kwargs): - """ R20+ViT-S/S16 hybrid @ 384x384. 
- """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -217,18 +207,9 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-S/S32 hybrid. """ backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r26_s32_384(pretrained=False, **kwargs): - """ R26+ViT-S/S32 hybrid @ 384x384. - """ - backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -237,18 +218,9 @@ def vit_base_r20_s16_224(pretrained=False, **kwargs): """ R20+ViT-B/S16 hybrid. """ backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_r20_s16_384(pretrained=False, **kwargs): - """ R20+ViT-B/S16 hybrid. - """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -257,18 +229,9 @@ def vit_base_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-B/S32 hybrid. """ backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_r26_s32_384(pretrained=False, **kwargs): - """ R26+ViT-B/S32 hybrid. - """ - backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -277,8 +240,9 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs): """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). 
""" backbone = _resnetv2((3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -287,29 +251,9 @@ def vit_large_r50_s32_224(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. """ backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): - """ R50+ViT-L/S32 hybrid. - """ - backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict( - embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_large_r50_s32_384(pretrained=False, **kwargs): - """ R50+ViT-L/S32 hybrid. - """ - backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -318,8 +262,9 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. """ backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) - model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -328,8 +273,9 @@ def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. """ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) - model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -338,8 +284,9 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. 
""" backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -348,6 +295,7 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs): """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. """ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model \ No newline at end of file From 288682796f4a18760f1276dd2d184f297fdf5182 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 16:40:12 -0700 Subject: [PATCH 18/24] Update benchmark script to add precision arg. Fix some downstream (DeiT) compat issues with latest changes. Bump version to 0.4.7 --- benchmark.py | 92 +++++++++++++----------- timm/models/regnet.py | 2 +- timm/models/vision_transformer.py | 13 ++-- timm/models/vision_transformer_hybrid.py | 12 ++++ timm/optim/__init__.py | 2 +- timm/optim/optim_factory.py | 16 ++++- timm/version.py | 2 +- train.py | 4 +- 8 files changed, 89 insertions(+), 54 deletions(-) diff --git a/benchmark.py b/benchmark.py index e692eacc..5f296c24 100755 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,7 @@ from contextlib import suppress from functools import partial from timm.models import create_model, is_model, list_models -from timm.optim import create_optimizer +from timm.optim import create_optimizer_v2 from timm.data import resolve_data_config from timm.utils import AverageMeter, setup_default_logging @@ -53,6 +53,10 @@ parser.add_argument('--detail', action='store_true', default=False, help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') +parser.add_argument('--num-warm-iter', default=10, type=int, + metavar='N', help='Number of warmup iterations (default: 10)') +parser.add_argument('--num-bench-iter', default=40, type=int, + metavar='N', help='Number of benchmark iterations (default: 40)') # common inference / train args parser.add_argument('--model', '-m', metavar='NAME', default='resnet50', @@ -70,11 +74,9 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL', parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--amp', action='store_true', default=False, - help='Use AMP mixed precision. 
Defaults to Apex, fallback to native Torch AMP.') -parser.add_argument('--apex-amp', action='store_true', default=False, - help='Use NVIDIA Apex AMP mixed precision') -parser.add_argument('--native-amp', action='store_true', default=False, - help='Use Native Torch AMP mixed precision') + help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.') +parser.add_argument('--precision', default='float32', type=str, + help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') @@ -117,28 +119,50 @@ def cuda_timestamp(sync=False, device=None): return time.perf_counter() -def count_params(model): +def count_params(model: nn.Module): return sum([m.numel() for m in model.parameters()]) +def resolve_precision(precision: str): + assert precision in ('amp', 'float16', 'bfloat16', 'float32') + use_amp = False + model_dtype = torch.float32 + data_dtype = torch.float32 + if precision == 'amp': + use_amp = True + elif precision == 'float16': + model_dtype = torch.float16 + data_dtype = torch.float16 + elif precision == 'bfloat16': + model_dtype = torch.bfloat16 + data_dtype = torch.bfloat16 + return use_amp, model_dtype, data_dtype + + class BenchmarkRunner: - def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs): + def __init__( + self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', + num_warm_iter=10, num_bench_iter=50, **kwargs): self.model_name = model_name self.detail = detail self.device = device + self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) + self.channels_last = kwargs.pop('channels_last', False) + self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress + self.model = create_model( model_name, num_classes=kwargs.pop('num_classes', None), in_chans=3, global_pool=kwargs.pop('gp', 'fast'), - scriptable=torchscript).to(device=self.device) + scriptable=torchscript) + self.model.to( + device=self.device, + dtype=self.model_dtype, + memory_format=torch.channels_last if self.channels_last else None) self.num_classes = self.model.num_classes self.param_count = count_params(self.model) _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) - - self.channels_last = kwargs.pop('channels_last', False) - self.use_amp = kwargs.pop('use_amp', '') - self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress if torchscript: self.model = torch.jit.script(self.model) @@ -147,16 +171,17 @@ class BenchmarkRunner: self.batch_size = kwargs.pop('batch_size', 256) self.example_inputs = None - self.num_warm_iter = 10 - self.num_bench_iter = 50 - self.log_freq = 10 + self.num_warm_iter = num_warm_iter + self.num_bench_iter = num_bench_iter + self.log_freq = num_bench_iter // 5 if 'cuda' in self.device: self.time_fn = partial(cuda_timestamp, device=self.device) else: self.time_fn = timestamp def _init_input(self): - self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device) + self.example_inputs = torch.randn( + (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype) if self.channels_last: self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last) @@ -166,10 +191,6 @@ class InferenceBenchmarkRunner(BenchmarkRunner): def __init__(self, model_name, device='cuda', torchscript=False, **kwargs): 
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) self.model.eval() - if self.use_amp == 'apex': - self.model = amp.initialize(self.model, opt_level='O1') - if self.channels_last: - self.model = self.model.to(memory_format=torch.channels_last) def run(self): def _step(): @@ -231,16 +252,11 @@ class TrainBenchmarkRunner(BenchmarkRunner): self.loss = nn.CrossEntropyLoss().to(self.device) self.target_shape = tuple() - self.optimizer = create_optimizer( + self.optimizer = create_optimizer_v2( self.model, opt_name=kwargs.pop('opt', 'sgd'), lr=kwargs.pop('lr', 1e-4)) - if self.use_amp == 'apex': - self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1') - if self.channels_last: - self.model = self.model.to(memory_format=torch.channels_last) - def _gen_target(self, batch_size): return torch.empty( (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes) @@ -331,6 +347,7 @@ class TrainBenchmarkRunner(BenchmarkRunner): samples_per_sec=round(num_samples / t_run_elapsed, 2), step_time=round(1000 * total_step / num_samples, 3), batch_size=self.batch_size, + img_size=self.input_size[-1], param_count=round(self.param_count / 1e6, 2), ) @@ -367,23 +384,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs): def benchmark(args): if args.amp: - if has_native_amp: - args.native_amp = True - elif has_apex: - args.apex_amp = True - else: - _logger.warning("Neither APEX or Native Torch AMP is available.") - if args.native_amp: - args.use_amp = 'native' - _logger.info('Benchmarking in mixed precision with native PyTorch AMP.') - elif args.apex_amp: - args.use_amp = 'apex' - _logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.') - else: - args.use_amp = '' - _logger.info('Benchmarking in float32. AMP not enabled.') + _logger.warning("Overriding precision to 'amp' since --amp flag set.") + args.precision = 'amp' + _logger.info(f'Benchmarking in {args.precision} precision. ' + f'{"NHWC" if args.channels_last else "NCHW"} layout. 
' + f'torchscript {"enabled" if args.torchscript else "disabled"}') bench_kwargs = vars(args).copy() + bench_kwargs.pop('amp') model = bench_kwargs.pop('model') batch_size = bench_kwargs.pop('batch_size') diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 40988946..26d8650b 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -89,7 +89,7 @@ default_cfgs = dict( regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'), regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'), regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'), - regnety_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'), + regnety_160=_cfg(url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'), regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), ) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 578a5f08..7a7afbff 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -281,8 +281,9 @@ class VisionTransformer(nn.Module): # Classifier head(s) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ - if num_classes > 0 and distilled else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() # Weight init assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') @@ -336,8 +337,8 @@ class VisionTransformer(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ - if num_classes > 0 and self.dist_token is not None else nn.Identity() + if self.head_dist is not None: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) @@ -356,8 +357,8 @@ class VisionTransformer(nn.Module): def forward(self, x): x = self.forward_features(x) - if isinstance(x, tuple): - x, x_dist = self.head(x[0]), self.head_dist(x[1]) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple if self.training and not torch.jit.is_scripting(): # during inference, return the average of both classifier predictions return x, x_dist diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 816bbc8e..1656559f 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -145,6 +145,12 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): + # NOTE this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) + + @register_model def vit_base_r50_s16_384(pretrained=False, **kwargs): """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). 
@@ -157,6 +163,12 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs): return model +@register_model +def vit_base_resnet50_384(pretrained=False, **kwargs): + # NOTE this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) + + @register_model def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 8bb21abb..7c4f4d36 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -10,4 +10,4 @@ from .radam import RAdam from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .optim_factory import create_optimizer, optimizer_kwargs \ No newline at end of file +from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs \ No newline at end of file diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index c3abdb76..a4844f14 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -55,7 +55,21 @@ def optimizer_kwargs(cfg): return kwargs -def create_optimizer( +def create_optimizer(args, model, filter_bias_and_bn=True): + """ Legacy optimizer factory for backwards compatibility. + NOTE: Use create_optimizer_v2 for new code. + """ + opt_args = dict(lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum) + if hasattr(args, 'opt_eps') and args.opt_eps is not None: + opt_args['eps'] = args.opt_eps + if hasattr(args, 'opt_betas') and args.opt_betas is not None: + opt_args['betas'] = args.opt_betas + if hasattr(args, 'opt_args') and args.opt_args is not None: + opt_args.update(args.opt_args) + return create_optimizer_v2(model, opt_name=args.opt, filter_bias_and_bn=filter_bias_and_bn, **opt_args) + + +def create_optimizer_v2( model: nn.Module, opt_name: str = 'sgd', lr: Optional[float] = None, diff --git a/timm/version.py b/timm/version.py index ab45471d..1e4826d6 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.6' +__version__ = '0.4.7' diff --git a/train.py b/train.py index 7b8e92e8..e1f308ae 100755 --- a/train.py +++ b/train.py @@ -33,7 +33,7 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_c convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy -from timm.optim import create_optimizer, optimizer_kwargs +from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler @@ -389,7 +389,7 @@ def main(): assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) - optimizer = create_optimizer(model, **optimizer_kwargs(cfg=args)) + optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing From c468c47a9cecf9a3cb872dc7e24d203086b0f3b2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 16:41:04 -0700 Subject: [PATCH 19/24] Add regnety_160 weights from DeiT teacher model, update that and my regnety_032 weights to use higher test size. 
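For reference, a minimal usage sketch (not part of this patch; the resolutions are simply copied from the config change below): after the update, the pretrained RegNetY-160 resolves through the normal factory and its default_cfg advertises the larger evaluation resolution.

    import timm

    model = timm.create_model('regnety_160', pretrained=True)
    cfg = model.default_cfg
    # cfg['input_size']      -> (3, 224, 224), the train-time resolution
    # cfg['test_input_size'] -> (3, 288, 288), with crop_pct=1.0, used for evaluation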
--- timm/models/regnet.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 26d8650b..3b7dba52 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -57,12 +57,13 @@ model_cfgs = dict( ) -def _cfg(url=''): +def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs } @@ -84,12 +85,16 @@ default_cfgs = dict( regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'), regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'), - regnety_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'), + regnety_032=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth', + crop_pct=1.0, test_input_size=(3, 288, 288)), regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'), regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'), regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'), regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'), - regnety_160=_cfg(url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'), + regnety_160=_cfg( + url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository + crop_pct=1.0, test_input_size=(3, 288, 288)), regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), ) @@ -328,11 +333,20 @@ class RegNet(nn.Module): return x +def _filter_fn(state_dict): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + if 'model' in state_dict: + # For DeiT trained regnety_160 pretraiend model + state_dict = state_dict['model'] + return state_dict + + def _create_regnet(variant, pretrained, **kwargs): return build_model_with_cfg( RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], + pretrained_filter_fn=_filter_fn, **kwargs) From 9071568f0e9045507b601eb6a9950fa364f24f27 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 17:22:27 -0700 Subject: [PATCH 20/24] Add weights for SE NFNet-L0 model, rename nfnet_l0b -> nfnet_l0. 82.75 top-1 @ 288. Add nfnet_l1 model def for training. 
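As an illustrative check of the rename from the user side (a sketch only; the names assume the model registry after this patch is applied):

    import timm

    # 'nfnet_l0a' / 'nfnet_l0b' are dropped; 'nfnet_l0', 'eca_nfnet_l0' and the new 'eca_nfnet_l1' remain
    print(timm.list_models('*nfnet_l*'))
    model = timm.create_model('nfnet_l0', pretrained=True)  # the new SE NFNet-L0 weights, 82.75 top-1 @ 288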
--- timm/models/nfnet.py | 48 +++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index c037c7ec..537fb15d 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -100,14 +100,16 @@ default_cfgs = dict( nfnet_f7s=_dcfg( url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), - nfnet_l0a=_dcfg( - url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)), - nfnet_l0b=_dcfg( - url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)), + nfnet_l0=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), eca_nfnet_l0=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth', hf_hub='timm/eca_nfnet_l0', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), + eca_nfnet_l1=_dcfg( + url='', + pool_size=(7, 7), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), nf_regnet_b0=_dcfg( url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), @@ -232,15 +234,15 @@ model_cfgs = dict( nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), - # Experimental 'light' versions of nfnet-f that are little leaner - nfnet_l0a=_nfnet_cfg( - depths=(1, 2, 6, 3), channels=(256, 512, 1280, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, - attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), - nfnet_l0b=_nfnet_cfg( - depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + # Experimental 'light' versions of NFNet-F that are little leaner + nfnet_l0=_nfnet_cfg( + depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), eca_nfnet_l0=_nfnet_cfg( - depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l1=_nfnet_cfg( + depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), # EffNet influenced RegNet defs. 
@@ -789,29 +791,29 @@ def nfnet_f7s(pretrained=False, **kwargs): @register_model -def nfnet_l0a(pretrained=False, **kwargs): - """ NFNet-L0a w/ SiLU - My experimental 'light' model w/ 1280 width stage 3, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio - """ - return _create_normfreenet('nfnet_l0a', pretrained=pretrained, **kwargs) - - -@register_model -def nfnet_l0b(pretrained=False, **kwargs): +def nfnet_l0(pretrained=False, **kwargs): """ NFNet-L0b w/ SiLU - My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio """ - return _create_normfreenet('nfnet_l0b', pretrained=pretrained, **kwargs) + return _create_normfreenet('nfnet_l0', pretrained=pretrained, **kwargs) @register_model def eca_nfnet_l0(pretrained=False, **kwargs): """ ECA-NFNet-L0 w/ SiLU - My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn """ return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs) +@register_model +def eca_nfnet_l1(pretrained=False, **kwargs): + """ ECA-NFNet-L1 w/ SiLU + My experimental 'light' model w/ F1 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) + + @register_model def nf_regnet_b0(pretrained=False, **kwargs): """ Normalization-Free RegNet-B0 From acbd698c83ef020c0f0ca3471e3945bd8611ebe3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 17:49:05 -0700 Subject: [PATCH 21/24] Update README.md with updates. Small tweak to head_dist handling. --- README.md | 18 ++++++++++++++++++ timm/models/vision_transformer.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a644d0d0..3f212d7d 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,22 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### April 1, 2021 +* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference +* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit) + * Merged distilled variant into main for torchscript compatibility + * Some `timm` cleanup/style tweaks and weights have hub download support +* Cleanup Vision Transformer (ViT) models + * Merge distilled (DeiT) model into main so that torchscript can work + * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch) + * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids + * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants + * nn.Sequential for block stack (does not break downstream compat) +* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT) +* Add RegNetY-160 weights from DeiT teacher model +* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288 +* Some fixes/improvements for TFDS dataset wrapper + ### March 17, 2021 * Add new ECA-NFNet-L0 (rename `nfnet_l0c`->`eca_nfnet_l0`) weights trained by myself. 
* 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224 @@ -189,6 +205,7 @@ A full version of the list below with source links can be found in the [document * NFNet-F - https://arxiv.org/abs/2102.06171 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692 * PNasNet - https://arxiv.org/abs/1712.00559 +* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 * RegNet - https://arxiv.org/abs/2003.13678 * RepVGG - https://arxiv.org/abs/2101.03697 * ResNet/ResNeXt @@ -204,6 +221,7 @@ A full version of the list below with source links can be found in the [document * ReXNet - https://arxiv.org/abs/2007.00992 * SelecSLS - https://arxiv.org/abs/1907.00837 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586 +* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112 * TResNet - https://arxiv.org/abs/2003.13630 * Vision Transformer - https://arxiv.org/abs/2010.11929 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667 diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 7a7afbff..cd73cc11 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -337,7 +337,7 @@ class VisionTransformer(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - if self.head_dist is not None: + if self.num_tokens == 2: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): From bf2ca6bdf474ab0f27b4fa7cdca9348e60058f20 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 18:11:51 -0700 Subject: [PATCH 22/24] Merge jax and original weight init --- timm/models/vision_transformer.py | 64 +++++++++++++------------------ 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index cd73cc11..81f8ae9f 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -289,40 +289,19 @@ class VisionTransformer(nn.Module): assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) if weight_init.startswith('jax'): # leave cls token as zeros to match jax impl for n, m in self.named_modules(): - _init_weights_jax(m, n, head_bias=head_bias) + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) else: trunc_normal_(self.cls_token, std=.02) - if self.dist_token is not None: - trunc_normal_(self.dist_token, std=.02) - for n, m in self.named_modules(): - self._init_weights(m, n, head_bias=head_bias) - - def _init_weights(self, m, n: str = '', head_bias: float = 0., init_conv=False): - # This impl does not exactly match the official JAX version. - # When called w/o n, head_bias, init_conv args it will behave exactly the same - # as my original init for compatibility with downstream use cases (ie DeiT). 
- if isinstance(m, nn.Linear): - if n.startswith('head'): - nn.init.zeros_(m.weight) - nn.init.constant_(m.bias, head_bias) - elif n.startswith('pre_logits'): - lecun_normal_(m.weight) - nn.init.zeros_(m.bias) - else: - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif init_conv and isinstance(m, nn.Conv2d): - # NOTE conv was left to pytorch default init originally - lecun_normal_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.zeros_(m.bias) - nn.init.ones_(m.weight) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) @torch.jit.ignore def no_weight_decay(self): @@ -369,9 +348,12 @@ class VisionTransformer(nn.Module): return x -def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.): - # A weight init scheme closer to the official JAX impl than my original init - # NOTE: requires module name so cannot be used via module.apply() +def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ if isinstance(m, nn.Linear): if n.startswith('head'): nn.init.zeros_(m.weight) @@ -380,13 +362,19 @@ def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.): lecun_normal_(m.weight) nn.init.zeros_(m.bias) else: - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - if 'mlp' in n: - nn.init.normal_(m.bias, 0, 1e-6) - else: + if jax_impl: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if 'mlp' in n: + nn.init.normal_(m.bias, std=1e-6) + else: + nn.init.zeros_(m.bias) + else: + trunc_normal_(m.weight, std=.02) + if m.bias is not None: nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): + elif jax_impl and isinstance(m, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init lecun_normal_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) From 2bb65bd8750755112da11124c2fdc4895bd971a8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 20:00:41 -0700 Subject: [PATCH 23/24] Wrong default_cfg pool_size for L1 --- timm/models/nfnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 537fb15d..1fa0d212 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -109,7 +109,7 @@ default_cfgs = dict( pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), eca_nfnet_l1=_dcfg( url='', - pool_size=(7, 7), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), nf_regnet_b0=_dcfg( url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), From 37c71a5609707d9d1ebdfb3a7e3f26ea542f60c6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 22:34:55 -0700 Subject: [PATCH 24/24] Some further create_optimizer_v2 tweaks, remove some redudnant code, add back safe model str. Benchmark step times per batch. 
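
A minimal usage sketch of the tweaked v2 factory interface (keyword names follow the optim_factory.py diff below; the resnet50 model, the hyper-parameter values, and the assumption that create_optimizer_v2 is re-exported from timm.optim are illustrative only):

from timm import create_model
from timm.optim import create_optimizer_v2  # assumed export from timm.optim

model = create_model('resnet50', num_classes=10)

# v2 factory takes optimizer_name / learning_rate rather than opt_name / lr
optimizer = create_optimizer_v2(
    model,
    optimizer_name='sgd',
    learning_rate=1e-2,
    weight_decay=1e-4,
    momentum=0.9,
    filter_bias_and_bn=True,  # factory default: bias / norm (1d) params excluded from weight decay
)

# argparse-style configs can be mapped with the optimizer_kwargs() helper, e.g.
#   optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
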
--- benchmark.py | 43 ++++++++++++++++++++----------------- timm/optim/optim_factory.py | 34 ++++++++++++++--------------- train.py | 2 +- 3 files changed, 41 insertions(+), 38 deletions(-) diff --git a/benchmark.py b/benchmark.py index 5f296c24..4812d85c 100755 --- a/benchmark.py +++ b/benchmark.py @@ -217,17 +217,18 @@ class InferenceBenchmarkRunner(BenchmarkRunner): delta_fwd = _step() total_step += delta_fwd num_samples += self.batch_size - if (i + 1) % self.log_freq == 0: + num_steps = i + 1 + if num_steps % self.log_freq == 0: _logger.info( - f"Infer [{i + 1}/{self.num_bench_iter}]." + f"Infer [{num_steps}/{self.num_bench_iter}]." f" {num_samples / total_step:0.2f} samples/sec." - f" {1000 * total_step / num_samples:0.3f} ms/sample.") + f" {1000 * total_step / num_steps:0.3f} ms/step.") t_run_end = self.time_fn(True) t_run_elapsed = t_run_end - t_run_start results = dict( samples_per_sec=round(num_samples / t_run_elapsed, 2), - step_time=round(1000 * total_step / num_samples, 3), + step_time=round(1000 * total_step / self.num_bench_iter, 3), batch_size=self.batch_size, img_size=self.input_size[-1], param_count=round(self.param_count / 1e6, 2), @@ -235,7 +236,7 @@ class InferenceBenchmarkRunner(BenchmarkRunner): _logger.info( f"Inference benchmark of {self.model_name} done. " - f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample") + f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step") return results @@ -254,8 +255,8 @@ class TrainBenchmarkRunner(BenchmarkRunner): self.optimizer = create_optimizer_v2( self.model, - opt_name=kwargs.pop('opt', 'sgd'), - lr=kwargs.pop('lr', 1e-4)) + optimizer_name=kwargs.pop('opt', 'sgd'), + learning_rate=kwargs.pop('lr', 1e-4)) def _gen_target(self, batch_size): return torch.empty( @@ -309,23 +310,24 @@ class TrainBenchmarkRunner(BenchmarkRunner): total_fwd += delta_fwd total_bwd += delta_bwd total_opt += delta_opt - if (i + 1) % self.log_freq == 0: + num_steps = (i + 1) + if num_steps % self.log_freq == 0: total_step = total_fwd + total_bwd + total_opt _logger.info( - f"Train [{i + 1}/{self.num_bench_iter}]." + f"Train [{num_steps}/{self.num_bench_iter}]." f" {num_samples / total_step:0.2f} samples/sec." - f" {1000 * total_fwd / num_samples:0.3f} ms/sample fwd," - f" {1000 * total_bwd / num_samples:0.3f} ms/sample bwd," - f" {1000 * total_opt / num_samples:0.3f} ms/sample opt." + f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd," + f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd," + f" {1000 * total_opt / num_steps:0.3f} ms/step opt." 
) total_step = total_fwd + total_bwd + total_opt t_run_elapsed = self.time_fn() - t_run_start results = dict( samples_per_sec=round(num_samples / t_run_elapsed, 2), - step_time=round(1000 * total_step / num_samples, 3), - fwd_time=round(1000 * total_fwd / num_samples, 3), - bwd_time=round(1000 * total_bwd / num_samples, 3), - opt_time=round(1000 * total_opt / num_samples, 3), + step_time=round(1000 * total_step / self.num_bench_iter, 3), + fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3), + bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3), + opt_time=round(1000 * total_opt / self.num_bench_iter, 3), batch_size=self.batch_size, img_size=self.input_size[-1], param_count=round(self.param_count / 1e6, 2), @@ -337,15 +339,16 @@ class TrainBenchmarkRunner(BenchmarkRunner): delta_step = _step(False) num_samples += self.batch_size total_step += delta_step - if (i + 1) % self.log_freq == 0: + num_steps = (i + 1) + if num_steps % self.log_freq == 0: _logger.info( - f"Train [{i + 1}/{self.num_bench_iter}]." + f"Train [{num_steps}/{self.num_bench_iter}]." f" {num_samples / total_step:0.2f} samples/sec." - f" {1000 * total_step / num_samples:0.3f} ms/sample.") + f" {1000 * total_step / num_steps:0.3f} ms/step.") t_run_elapsed = self.time_fn() - t_run_start results = dict( samples_per_sec=round(num_samples / t_run_elapsed, 2), - step_time=round(1000 * total_step / num_samples, 3), + step_time=round(1000 * total_step / self.num_bench_iter, 3), batch_size=self.batch_size, img_size=self.input_size[-1], param_count=round(self.param_count / 1e6, 2), diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index a4844f14..a10607cb 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -44,14 +44,17 @@ def optimizer_kwargs(cfg): """ cfg/argparse to kwargs helper Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. """ - kwargs = dict(opt_name=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay) + kwargs = dict( + optimizer_name=cfg.opt, + learning_rate=cfg.lr, + weight_decay=cfg.weight_decay, + momentum=cfg.momentum) if getattr(cfg, 'opt_eps', None) is not None: kwargs['eps'] = cfg.opt_eps if getattr(cfg, 'opt_betas', None) is not None: kwargs['betas'] = cfg.opt_betas if getattr(cfg, 'opt_args', None) is not None: kwargs.update(cfg.opt_args) - kwargs['momentum'] = cfg.momentum return kwargs @@ -59,20 +62,17 @@ def create_optimizer(args, model, filter_bias_and_bn=True): """ Legacy optimizer factory for backwards compatibility. NOTE: Use create_optimizer_v2 for new code. 
""" - opt_args = dict(lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum) - if hasattr(args, 'opt_eps') and args.opt_eps is not None: - opt_args['eps'] = args.opt_eps - if hasattr(args, 'opt_betas') and args.opt_betas is not None: - opt_args['betas'] = args.opt_betas - if hasattr(args, 'opt_args') and args.opt_args is not None: - opt_args.update(args.opt_args) - return create_optimizer_v2(model, opt_name=args.opt, filter_bias_and_bn=filter_bias_and_bn, **opt_args) + return create_optimizer_v2( + model, + **optimizer_kwargs(cfg=args), + filter_bias_and_bn=filter_bias_and_bn, + ) def create_optimizer_v2( model: nn.Module, - opt_name: str = 'sgd', - lr: Optional[float] = None, + optimizer_name: str = 'sgd', + learning_rate: Optional[float] = None, weight_decay: float = 0., momentum: float = 0.9, filter_bias_and_bn: bool = True, @@ -86,8 +86,8 @@ def create_optimizer_v2( Args: model (nn.Module): model containing parameters to optimize - opt_name: name of optimizer to create - lr: initial learning rate + optimizer_name: name of optimizer to create + learning_rate: initial learning rate weight_decay: weight decay to apply in optimizer momentum: momentum for momentum based optimizers (others may use betas via kwargs) filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay @@ -96,7 +96,7 @@ def create_optimizer_v2( Returns: Optimizer """ - opt_lower = opt_name.lower() + opt_lower = optimizer_name.lower() if weight_decay and filter_bias_and_bn: skip = {} if hasattr(model, 'no_weight_decay'): @@ -108,7 +108,7 @@ def create_optimizer_v2( if 'fused' in opt_lower: assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' - opt_args = dict(lr=lr, weight_decay=weight_decay, **kwargs) + opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs) opt_split = opt_lower.split('_') opt_lower = opt_split[-1] if opt_lower == 'sgd' or opt_lower == 'nesterov': @@ -132,7 +132,7 @@ def create_optimizer_v2( elif opt_lower == 'adadelta': optimizer = optim.Adadelta(parameters, **opt_args) elif opt_lower == 'adafactor': - if not lr: + if not learning_rate: opt_args['lr'] = None optimizer = Adafactor(parameters, **opt_args) elif opt_lower == 'adahessian': diff --git a/train.py b/train.py index e1f308ae..89ade4a1 100755 --- a/train.py +++ b/train.py @@ -552,7 +552,7 @@ def main(): else: exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), - args.model, + safe_model_name(args.model), str(data_config['input_size'][-1]) ]) output_dir = get_outdir(args.output if args.output else './output/train', exp_name)