Merge remote-tracking branch 'origin/benchmark-fixes-vit_hybrids' into pit_and_vit_update

pull/533/head
Ross Wightman 4 years ago
commit a5310a3451

@@ -0,0 +1,470 @@
#!/usr/bin/env python3
""" Model Benchmark Script
An inference and train step benchmark script for timm models.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import json
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from functools import partial
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer
from timm.data import resolve_data_config
from timm.loss import LabelSmoothingCrossEntropy
from timm.utils import AverageMeter, setup_default_logging
has_apex = False
try:
from apex import amp
has_apex = True
except ImportError:
pass
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch Benchmark')
# benchmark specific args
parser.add_argument('--model-list', metavar='NAME', default='',
help='txt file based list of model names to benchmark')
parser.add_argument('--bench', default='both', type=str,
help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False,
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
# common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
help='model architecture (default: resnet50)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
# model regularization / loss params that impact model or loss fn
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
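# Example invocations (illustrative; assumes this script is saved as benchmark.py and
# 'models.txt' is a placeholder text file with one model name per line):
#   python benchmark.py --model resnet50 --bench inference --batch-size 128 --amp
#   python benchmark.py --model-list models.txt --bench both --detail --results-file benchmark.csv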
def timestamp(sync=False):
return time.perf_counter()
def cuda_timestamp(sync=False, device=None):
if sync:
torch.cuda.synchronize(device=device)
return time.perf_counter()
def count_params(model):
return sum([m.numel() for m in model.parameters()])
class BenchmarkRunner:
def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs):
self.model_name = model_name
self.detail = detail
self.device = device
self.model = create_model(
model_name,
num_classes=kwargs.pop('num_classes', None),
in_chans=3,
global_pool=kwargs.pop('gp', 'fast'),
scriptable=torchscript).to(device=self.device)
self.num_classes = self.model.num_classes
self.param_count = count_params(self.model)
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
self.channels_last = kwargs.pop('channels_last', False)
self.use_amp = kwargs.pop('use_amp', '')
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress
if torchscript:
self.model = torch.jit.script(self.model)
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=True)
self.input_size = data_config['input_size']
self.batch_size = kwargs.pop('batch_size', 256)
self.example_inputs = None
self.num_warm_iter = 10
self.num_bench_iter = 50
self.log_freq = 10
if 'cuda' in self.device:
self.time_fn = partial(cuda_timestamp, device=self.device)
else:
self.time_fn = timestamp
def _init_input(self):
self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device)
if self.channels_last:
self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)
class InferenceBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.eval()
if self.use_amp == 'apex':
self.model = amp.initialize(self.model, opt_level='O1')
if self.channels_last:
self.model = self.model.to(memory_format=torch.channels_last)
def run(self):
def _step():
t_step_start = self.time_fn()
with self.amp_autocast():
output = self.model(self.example_inputs)
t_step_end = self.time_fn(True)
return t_step_end - t_step_start
_logger.info(
f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
f'input size {self.input_size} and batch size {self.batch_size}.')
with torch.no_grad():
self._init_input()
for _ in range(self.num_warm_iter):
_step()
total_step = 0.
num_samples = 0
t_run_start = self.time_fn()
for i in range(self.num_bench_iter):
delta_fwd = _step()
total_step += delta_fwd
num_samples += self.batch_size
if (i + 1) % self.log_freq == 0:
_logger.info(
f"Infer [{i + 1}/{self.num_bench_iter}]."
f" {num_samples / total_step:0.2f} samples/sec."
f" {1000 * total_step / num_samples:0.3f} ms/sample.")
t_run_end = self.time_fn(True)
t_run_elapsed = t_run_end - t_run_start
results = dict(
samples_per_sec=round(num_samples / t_run_elapsed, 2),
step_time=round(1000 * total_step / num_samples, 3),
batch_size=self.batch_size,
img_size=self.input_size[-1],
param_count=round(self.param_count / 1e6, 2),
)
_logger.info(
f"Inference benchmark of {self.model_name} done. "
f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample")
return results
class TrainBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.train()
smoothing = kwargs.pop('smoothing', 0)
if smoothing > 0:
# use timm's label-smoothing loss when a smoothing factor is requested
self.loss = LabelSmoothingCrossEntropy(smoothing=smoothing).to(self.device)
else:
self.loss = nn.CrossEntropyLoss().to(self.device)
self.target_shape = tuple()
self.optimizer = create_optimizer(
self.model,
opt_name=kwargs.pop('opt', 'sgd'),
lr=kwargs.pop('lr', 1e-4))
if self.use_amp == 'apex':
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')
if self.channels_last:
self.model = self.model.to(memory_format=torch.channels_last)
def _gen_target(self, batch_size):
return torch.empty(
(batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
def run(self):
def _step(detail=False):
self.optimizer.zero_grad() # can this be ignored?
t_start = self.time_fn()
t_fwd_end = t_start
t_bwd_end = t_start
with self.amp_autocast():
output = self.model(self.example_inputs)
if isinstance(output, tuple):
output = output[0]
if detail:
t_fwd_end = self.time_fn(True)
target = self._gen_target(output.shape[0])
self.loss(output, target).backward()
if detail:
t_bwd_end = self.time_fn(True)
self.optimizer.step()
t_end = self.time_fn(True)
if detail:
delta_fwd = t_fwd_end - t_start
delta_bwd = t_bwd_end - t_fwd_end
delta_opt = t_end - t_bwd_end
return delta_fwd, delta_bwd, delta_opt
else:
delta_step = t_end - t_start
return delta_step
_logger.info(
f'Running train benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
f'input size {self.input_size} and batch size {self.batch_size}.')
self._init_input()
for _ in range(self.num_warm_iter):
_step()
t_run_start = self.time_fn()
if self.detail:
total_fwd = 0.
total_bwd = 0.
total_opt = 0.
num_samples = 0
for i in range(self.num_bench_iter):
delta_fwd, delta_bwd, delta_opt = _step(True)
num_samples += self.batch_size
total_fwd += delta_fwd
total_bwd += delta_bwd
total_opt += delta_opt
if (i + 1) % self.log_freq == 0:
total_step = total_fwd + total_bwd + total_opt
_logger.info(
f"Train [{i + 1}/{self.num_bench_iter}]."
f" {num_samples / total_step:0.2f} samples/sec."
f" {1000 * total_fwd / num_samples:0.3f} ms/sample fwd,"
f" {1000 * total_bwd / num_samples:0.3f} ms/sample bwd,"
f" {1000 * total_opt / num_samples:0.3f} ms/sample opt."
)
total_step = total_fwd + total_bwd + total_opt
t_run_elapsed = self.time_fn() - t_run_start
results = dict(
samples_per_sec=round(num_samples / t_run_elapsed, 2),
step_time=round(1000 * total_step / num_samples, 3),
fwd_time=round(1000 * total_fwd / num_samples, 3),
bwd_time=round(1000 * total_bwd / num_samples, 3),
opt_time=round(1000 * total_opt / num_samples, 3),
batch_size=self.batch_size,
img_size=self.input_size[-1],
param_count=round(self.param_count / 1e6, 2),
)
else:
total_step = 0.
num_samples = 0
for i in range(self.num_bench_iter):
delta_step = _step(False)
num_samples += self.batch_size
total_step += delta_step
if (i + 1) % self.log_freq == 0:
_logger.info(
f"Train [{i + 1}/{self.num_bench_iter}]."
f" {num_samples / total_step:0.2f} samples/sec."
f" {1000 * total_step / num_samples:0.3f} ms/sample.")
t_run_elapsed = self.time_fn() - t_run_start
results = dict(
samples_per_sec=round(num_samples / t_run_elapsed, 2),
step_time=round(1000 * total_step / num_samples, 3),
batch_size=self.batch_size,
param_count=round(self.param_count / 1e6, 2),
)
_logger.info(
f"Train benchmark of {self.model_name} done. "
f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample")
return results
def decay_batch_exp(batch_size, factor=0.5, divisor=16):
out_batch_size = batch_size * factor
if out_batch_size > divisor:
out_batch_size = (out_batch_size + 1) // divisor * divisor
else:
out_batch_size = batch_size - 1
return max(0, int(out_batch_size))
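# Worked example (illustrative): starting from 256, repeated calls yield
# 128, 64, 32 (halved and rounded to a multiple of divisor=16), then 31, 30, ...
# once batch_size * factor is no longer above the divisor.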
def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
batch_size = initial_batch_size
results = dict()
while batch_size >= 1:
try:
bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
results = bench.run()
return results
except RuntimeError as e:
torch.cuda.empty_cache()
batch_size = decay_batch_exp(batch_size)
print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
return results
def benchmark(args):
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
else:
_logger.warning("Neither APEX or Native Torch AMP is available.")
if args.native_amp:
args.use_amp = 'native'
_logger.info('Benchmarking in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
args.use_amp = 'apex'
_logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.')
else:
args.use_amp = ''
_logger.info('Benchmarking in float32. AMP not enabled.')
bench_kwargs = vars(args).copy()
model = bench_kwargs.pop('model')
batch_size = bench_kwargs.pop('batch_size')
bench_fns = (InferenceBenchmarkRunner,)
prefixes = ('infer',)
if args.bench == 'both':
bench_fns = (
InferenceBenchmarkRunner,
TrainBenchmarkRunner
)
prefixes = ('infer', 'train')
elif args.bench == 'train':
bench_fns = TrainBenchmarkRunner,
prefixes = 'train',
model_results = OrderedDict(model=model)
for prefix, bench_fn in zip(prefixes, bench_fns):
run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
if prefix:
run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
model_results.update(run_results)
param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
model_results.setdefault('param_count', param_count)
model_results.pop('train_param_count', 0)
return model_results
def main():
setup_default_logging()
args = parser.parse_args()
model_cfgs = []
model_names = []
if args.model_list:
args.model = ''
with open(args.model_list) as f:
model_names = [line.rstrip() for line in f]
model_cfgs = [(n, None) for n in model_names]
elif args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
model_cfgs = [(n, None) for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter
model_names = list_models(args.model)
model_cfgs = [(n, None) for n in model_names]
if len(model_cfgs):
results_file = args.results_file or './benchmark.csv'
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = []
try:
for m, _ in model_cfgs:
if not m:
continue
args.model = m
r = benchmark(args)
results.append(r)
except KeyboardInterrupt as e:
pass
sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec'
results = sorted(results, key=lambda x: x[sort_key], reverse=True)
if len(results):
write_results(results_file, results)
import json
json_str = json.dumps(results, indent=4)
print(json_str)
else:
benchmark(args)
def write_results(results_file, results):
with open(results_file, mode='w') as cf:
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__':
main()
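A minimal sketch of driving the runners above from Python instead of the CLI (assumes the script is importable as `benchmark`; the model name, batch size and default CUDA device are illustrative):

    from benchmark import InferenceBenchmarkRunner

    runner = InferenceBenchmarkRunner(model_name='resnet50', batch_size=64)  # device='cuda' by default
    results = runner.run()
    print(results)  # dict w/ samples_per_sec, step_time, batch_size, img_size, param_count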

@@ -5,7 +5,7 @@ from .constants import *
 _logger = logging.getLogger(__name__)


-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     new_config = {}
     default_cfg = default_cfg
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):

@@ -73,12 +73,13 @@ class IterableImageDataset(data.IterableDataset):
             batch_size=None,
             class_map='',
             load_bytes=False,
+            repeats=0,
             transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
         else:
             self.parser = parser
         self.transform = transform

@@ -23,6 +23,7 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
+        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
             root = _search_split(root, split)
         ds = ImageDataset(root, parser=name, **kwargs)

@@ -29,6 +29,11 @@ SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
 PREFETCH_SIZE = 4096  # samples to prefetch


+def even_split_indices(split, n, num_samples):
+    partitions = [round(i * num_samples / n) for i in range(n + 1)]
+    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
 class ParserTfds(Parser):
     """ Wrap Tensorflow Datasets for use in PyTorch
@@ -52,7 +57,7 @@ class ParserTfds(Parser):
     components.
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
         super().__init__()
         self.root = root
         self.split = split
@@ -62,6 +67,8 @@ class ParserTfds(Parser):
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
+        self.repeats = repeats
+        self.subsplit = None
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@@ -95,6 +102,7 @@ class ParserTfds(Parser):
         if worker_info is not None:
             self.worker_info = worker_info
             num_workers = worker_info.num_workers
+            global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id

             # FIXME I need to spend more time figuring out the best way to distribute/split data across
@@ -114,19 +122,31 @@ class ParserTfds(Parser):
             # split = split + '[{}:]'.format(start)
             # else:
             # split = split + '[{}:{}]'.format(start, start + split_size)

-        input_context = tf.distribute.InputContext(
-            num_input_pipelines=self.dist_num_replicas * num_workers,
-            input_pipeline_id=self.dist_rank * num_workers + worker_id,
-            num_replicas_in_sync=self.dist_num_replicas  # FIXME does this have any impact?
-        )
-        read_config = tfds.ReadConfig(input_context=input_context)
-        ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
+            if not self.is_training and '[' not in self.split:
+                # If not training, and split doesn't define a subsplit, manually split the dataset
+                # for more even samples / worker
+                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
+                    self.dist_rank * num_workers + worker_id]
+
+        if self.subsplit is None:
+            input_context = tf.distribute.InputContext(
+                num_input_pipelines=self.dist_num_replicas * num_workers,
+                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
+            )
+        else:
+            input_context = None
+        read_config = tfds.ReadConfig(
+            shuffle_seed=42,
+            shuffle_reshuffle_each_iteration=True,
+            input_context=input_context)
+        ds = self.builder.as_dataset(
+            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)

         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
         ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
         ds.options().experimental_threading.max_intra_op_parallelism = 1
-        if self.is_training:
+        if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
@@ -143,7 +163,7 @@ class ParserTfds(Parser):
         #    This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #    batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
@@ -160,8 +180,8 @@ class ParserTfds(Parser):
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
-            assert target_sample_count - sample_count == 1  # should only be off by 1 or sharding is not optimal
+            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
+            # approach is not optimal
             yield img, sample['label']  # yield prev sample again
             sample_count += 1
@@ -176,7 +196,7 @@ class ParserTfds(Parser):
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
         # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(self.num_samples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples

@@ -29,6 +29,7 @@ from .tnt import *
 from .tresnet import *
 from .vgg import *
 from .vision_transformer import *
+from .vision_transformer_hybrid import *
 from .vovnet import *
 from .xception import *
 from .xception_aligned import *

@@ -31,4 +31,4 @@ from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
-from .weight_init import trunc_normal_
+from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_

@@ -2,6 +2,8 @@ import torch
 import math
 import warnings

+from torch.nn.init import _calculate_fan_in_and_fan_out
+

 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
@@ -58,3 +60,30 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
         >>> nn.init.trunc_normal_(w)
     """
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in':
denom = fan_in
elif mode == 'fan_out':
denom = fan_out
elif mode == 'fan_avg':
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
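# Usage sketch (illustrative): these helpers modify the given tensor in place, e.g.
#   w = torch.empty(768, 384)
#   variance_scaling_(w, scale=1.0, mode='fan_in', distribution='truncated_normal')
#   lecun_normal_(w)  # same as the variance_scaling_ call above (fan_in + truncated normal)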

@@ -274,7 +274,9 @@ class ResNetStage(nn.Module):
         return x


-def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
+def create_resnetv2_stem(
+        in_chs, out_chs=64, stem_type='', preact=True,
+        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
     stem = OrderedDict()
     assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
@@ -322,7 +324,8 @@ class ResNetV2(nn.Module):
         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
-        self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
+        self.stem = create_resnetv2_stem(
+            in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
         stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
         self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

@@ -29,9 +29,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
-from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
-from .resnet import resnet26d, resnet50d
-from .resnetv2 import ResNetV2
+from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
 from .registry import register_model

 _logger = logging.getLogger(__name__)
@@ -98,25 +96,21 @@ default_cfgs = {
         hf_hub='timm/vit_huge_patch14_224_in21k',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

-    # hybrid models (weights ported from official Google JAX impl)
-    'vit_base_resnet50_224_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
-        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
-    'vit_base_resnet50_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),
-
-    # hybrid models (my experiments)
-    'vit_small_resnet26d_224': _cfg(),
-    'vit_small_resnet50d_s3_224': _cfg(),
-    'vit_base_resnet26d_224': _cfg(),
-    'vit_base_resnet50d_224': _cfg(),
-
     # deit models (FB weights)
     'vit_deit_tiny_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
+    'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843),
+    'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_deit_small_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
+    'vit_deit_small_patch16_224_in21k': _cfg(num_classes=21843),
+    'vit_deit_small_patch16_384': _cfg(input_size=(3, 384, 384)),
+    'vit_deit_small_patch32_224': _cfg(),
+    'vit_deit_small_patch32_224_in21k': _cfg(num_classes=21843),
+    'vit_deit_small_patch32_384': _cfg(input_size=(3, 384, 384)),
     'vit_deit_base_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
     'vit_deit_base_patch16_384': _cfg(
@@ -161,7 +155,6 @@ class Attention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
         self.scale = qk_scale or head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
@@ -231,17 +224,17 @@ class HybridEmbed(nn.Module):
     """ CNN Feature Map Embedding
     Extract feature map from CNN, flatten, project to embedding dim.
     """
-    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
         super().__init__()
         assert isinstance(backbone, nn.Module)
         img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
         self.img_size = img_size
+        self.patch_size = patch_size
         self.backbone = backbone
         if feature_size is None:
             with torch.no_grad():
-                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
-                # map for all networks, the feature metadata has reliable channel and stride info, but using
-                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                # NOTE Most reliable way of determining output dims is to run forward pass
                 training = backbone.training
                 if training:
                     backbone.eval()
@@ -257,8 +250,9 @@ class HybridEmbed(nn.Module):
             feature_dim = self.backbone.feature_info.channels()[-1]
         else:
             feature_dim = self.backbone.num_features
-        self.num_patches = feature_size[0] * feature_size[1]
-        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
+        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+        self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
+        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

     def forward(self, x):
         x = self.backbone(x)
@@ -277,10 +271,11 @@ class VisionTransformer(nn.Module):
     Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
         - https://arxiv.org/abs/2012.12877
     """

     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                  num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
-                 weight_init=''):
+                 act_layer=None, weight_init=''):
         """
         Args:
             img_size (int, tuple): input image size
@@ -307,10 +302,12 @@ class VisionTransformer(nn.Module):
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_tokens = 2 if distilled else 1
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
+        patch_size = patch_size or (1 if hybrid_backbone is not None else 16)

         if hybrid_backbone is not None:
             self.patch_embed = HybridEmbed(
-                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+                hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         else:
             self.patch_embed = PatchEmbed(
                 img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
@@ -325,7 +322,7 @@ class VisionTransformer(nn.Module):
         self.blocks = nn.Sequential(*[
             Block(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
         self.norm = norm_layer(embed_dim)
@@ -344,20 +341,44 @@ class VisionTransformer(nn.Module):
         self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
             if num_classes > 0 and distilled else nn.Identity()

+        # Weight init
+        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
         trunc_normal_(self.pos_embed, std=.02)
-        trunc_normal_(self.cls_token, std=.02)
-        if self.dist_token is not None:
-            trunc_normal_(self.dist_token, std=.02)
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
+        if weight_init.startswith('jax'):
+            # leave cls token as zeros to match jax impl
+            for n, m in self.named_modules():
+                _init_weights_jax(m, n, head_bias=head_bias)
+        else:
+            trunc_normal_(self.cls_token, std=.02)
+            if self.dist_token is not None:
+                trunc_normal_(self.dist_token, std=.02)
+            for n, m in self.named_modules():
+                self._init_weights(m, n, head_bias=head_bias)
+
+    def _init_weights(self, m, n: str = '', head_bias: float = 0., init_conv=False):
+        # This impl does not exactly match the official JAX version.
+        # When called w/o n, head_bias, init_conv args it will behave exactly the same
+        # as my original init for compatibility with downstream use cases (ie DeiT).
         if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
+            if n.startswith('head'):
+                nn.init.zeros_(m.weight)
+                nn.init.constant_(m.bias, head_bias)
+            elif n.startswith('pre_logits'):
+                lecun_normal_(m.weight)
+                nn.init.zeros_(m.bias)
+            else:
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+        elif init_conv and isinstance(m, nn.Conv2d):
+            # NOTE conv was left to pytorch default init originally
+            lecun_normal_(m.weight)
+            if m.bias is not None:
+                nn.init.zeros_(m.bias)
         elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
+            nn.init.zeros_(m.bias)
+            nn.init.ones_(m.weight)

     @torch.jit.ignore
     def no_weight_decay(self):
@@ -404,6 +425,32 @@ class VisionTransformer(nn.Module):
         return x
def _init_weights_jax(m: nn.Module, n: str, head_bias: float = 0.):
# A weight init scheme closer to the official JAX impl than my original init
# NOTE: requires module name so cannot be used via module.apply()
if isinstance(m, nn.Linear):
if n.startswith('head'):
nn.init.zeros_(m.weight)
nn.init.constant_(m.bias, head_bias)
elif n.startswith('pre_logits'):
lecun_normal_(m.weight)
nn.init.zeros_(m.bias)
else:
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
if 'mlp' in n:
nn.init.normal_(m.bias, 0, 1e-6)
else:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
lecun_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
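# NOTE: as in VisionTransformer.__init__ above, this is applied per named module because the
# init depends on the module name, e.g.:
#   for n, m in model.named_modules():
#       _init_weights_jax(m, n, head_bias=0.)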
def resize_pos_embed(posemb, posemb_new, num_tokens=1):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
@@ -411,7 +458,7 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     ntok_new = posemb_new.shape[1]
     if num_tokens:
         posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
-        ntok_new -= 1
+        ntok_new -= num_tokens
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
     gs_old = int(math.sqrt(len(posemb_grid)))
@@ -474,7 +521,11 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):

 @register_model
 def vit_small_patch16_224(pretrained=False, **kwargs):
-    """ My custom 'small' ViT model. Depth=8, heads=8= mlp_ratio=3."""
+    """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
+    NOTE:
+        * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
+        * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
+    """
     model_kwargs = dict(
         patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
         qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
@@ -620,92 +671,80 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):

 @register_model
-def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
-    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    """
-    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
-    backbone = ResNetV2(
-        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
-        preact=False, stem_type='same', conv_layer=StdConv2dSame)
-    model_kwargs = dict(
-        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
-        representation_size=768, **kwargs)
-    model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_base_resnet50_384(pretrained=False, **kwargs):
-    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
-    """
-    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
-    backbone = ResNetV2(
-        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
-        preact=False, stem_type='same', conv_layer=StdConv2dSame)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
-    model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_small_resnet26d_224(pretrained=False, **kwargs):
-    """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
-    """
-    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
-    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
-    model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
-    """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
-    """
-    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
-    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
-    model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_base_resnet26d_224(pretrained=False, **kwargs):
-    """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
-    """
-    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
-    model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_base_resnet50d_224(pretrained=False, **kwargs):
-    """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
-    """
-    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
-    model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
 def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
     return model


+@register_model
+def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
+    """ DeiT-tiny model"""
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, representation_size=192, **kwargs)
+    model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_tiny_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-tiny model"""
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+    model = _create_vision_transformer('vit_deit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_deit_small_patch16_224(pretrained=False, **kwargs):
     """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_deit_small_patch16_224_in21k(pretrained=False, **kwargs):
+    """ DeiT-small """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
+    model = _create_vision_transformer('vit_deit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_small_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-small """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_deit_small_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_small_patch32_224(pretrained=False, **kwargs):
+    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs):
+    """ DeiT-small """
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
+    model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_small_patch32_384(pretrained=False, **kwargs):
+    """ DeiT-small """
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs)
+    return model

@@ -0,0 +1,353 @@
""" Hybrid Vision Transformer (ViT) in PyTorch
A PyTorch implementation of the Hybrid Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- https://arxiv.org/abs/2010.11929
NOTE This relies on code in vision_transformer.py. The hybrid model definitions were moved here to
keep file sizes sane.
Hacked together by / Copyright 2020 Ross Wightman
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import StdConv2dSame, StdConv2d, to_2tuple
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .registry import register_model
from timm.models.vision_transformer import _create_vision_transformer
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
'vit_base_r50_s16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
# hybrid in-1k models (weights ported from official JAX impl)
'vit_base_r50_s16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
'vit_tiny_r_s16_p8_224': _cfg(),
'vit_tiny_r_s16_p8_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r_s16_p8_224': _cfg(
crop_pct=1.0),
'vit_small_r_s16_p8_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r20_s16_p2_224': _cfg(),
'vit_small_r20_s16_p2_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r20_s16_224': _cfg(),
'vit_small_r20_s16_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r26_s32_224': _cfg(),
'vit_small_r26_s32_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r20_s16_224': _cfg(),
'vit_base_r20_s16_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r26_s32_224': _cfg(),
'vit_base_r26_s32_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r50_s16_224': _cfg(),
'vit_large_r50_s32_224': _cfg(),
'vit_large_r50_s32_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
}
def _resnetv2(layers=(3, 4, 9), **kwargs):
""" ResNet-V2 backbone helper"""
padding_same = kwargs.get('padding_same', True)
if padding_same:
stem_type = 'same'
conv_layer = StdConv2dSame
else:
stem_type = ''
conv_layer = StdConv2d
if len(layers):
backbone = ResNetV2(
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
preact=False, stem_type=stem_type, conv_layer=conv_layer)
else:
backbone = create_resnetv2_stem(
kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer)
return backbone
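# Illustrative use of the helper above: _resnetv2(layers=(3, 4, 9)) builds a 3-stage ResNetV2 backbone
# (StdConv + GroupNorm, no pre-activation, no head/pooling), while _resnetv2(layers=()) returns just the
# stem for the stem-only 'R' hybrid models below.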
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(
embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2((3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2((2, 4), **kwargs)
model_kwargs = dict(
patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_p2_384(pretrained=False, **kwargs):
""" R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384.
"""
backbone = _resnetv2((2, 4), **kwargs)
model_kwargs = dict(
embed_dim=384, patch_size=2, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-S/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_384(pretrained=False, **kwargs):
""" R20+ViT-S/S16 hybrid @ 384x384.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_small_r26_s32_384(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid @ 384x384.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_r20_s16_384(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_r26_s32_224(pretrained=False, **kwargs):
""" R26+ViT-B/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_r26_s32_384(pretrained=False, **kwargs):
""" R26+ViT-B/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_r50_s16_224(pretrained=False, **kwargs):
""" R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
"""
backbone = _resnetv2((3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    model_kwargs = dict(
        embed_dim=1024, depth=24, num_heads=16, representation_size=1024, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
"""
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
"""
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
"""
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model


@register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs):
""" Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
"""
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
return model
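
# ---------------------------------------------------------------------------
# Editor's note: an illustrative usage sketch, not part of the module above.
# The registered names appear to encode the hybrid config: rNN is the ResNetV2
# backbone built from the stage tuple passed to _resnetv2() (depth is roughly
# 3 * sum(stages) + 2, e.g. (2, 2, 2, 2) -> R26, (3, 4, 9) -> R50), sNN is the
# overall patch-embedding stride of the hybrid, and pN (where present) is the
# patch size applied on top of the backbone feature map. Assuming a timm
# install that includes the registrations above:

import torch
import timm

hybrid = timm.create_model('vit_small_r26_s32_224', pretrained=False, num_classes=1000)
hybrid.eval()
with torch.no_grad():
    logits = hybrid(torch.randn(1, 3, 224, 224))  # expected shape: (1, 1000)
# ---------------------------------------------------------------------------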

@@ -10,4 +10,4 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
-from .optim_factory import create_optimizer
+from .optim_factory import create_optimizer, optimizer_kwargs

@@ -1,8 +1,11 @@
 """ Optimizer Factory w/ Custom Weight Decay
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Optional
+
 import torch
-from torch import optim as optim
+import torch.nn as nn
+import torch.optim as optim
 
 from .adafactor import Adafactor
 from .adahessian import Adahessian
@@ -37,9 +40,49 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
         {'params': decay, 'weight_decay': weight_decay}]
 
 
-def create_optimizer(args, model, filter_bias_and_bn=True):
-    opt_lower = args.opt.lower()
-    weight_decay = args.weight_decay
+def optimizer_kwargs(cfg):
+    """ cfg/argparse to kwargs helper
+    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
+    """
+    kwargs = dict(opt_name=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay)
+    if getattr(cfg, 'opt_eps', None) is not None:
+        kwargs['eps'] = cfg.opt_eps
+    if getattr(cfg, 'opt_betas', None) is not None:
+        kwargs['betas'] = cfg.opt_betas
+    if getattr(cfg, 'opt_args', None) is not None:
+        kwargs.update(cfg.opt_args)
+    kwargs['momentum'] = cfg.momentum
+    return kwargs
+
+
+def create_optimizer(
+        model: nn.Module,
+        opt_name: str = 'sgd',
+        lr: Optional[float] = None,
+        weight_decay: float = 0.,
+        momentum: float = 0.9,
+        filter_bias_and_bn: bool = True,
+        **kwargs):
+    """ Create an optimizer.
+
+    TODO currently the model is passed in and all parameters are selected for optimization.
+    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
+      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
+      * expose the parameters interface and leave it up to caller
+
+    Args:
+        model (nn.Module): model containing parameters to optimize
+        opt_name: name of optimizer to create
+        lr: initial learning rate
+        weight_decay: weight decay to apply in optimizer
+        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
+        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
+        **kwargs: extra optimizer specific kwargs to pass through
+
+    Returns:
+        Optimizer
+    """
+    opt_lower = opt_name.lower()
     if weight_decay and filter_bias_and_bn:
         skip = {}
         if hasattr(model, 'no_weight_decay'):
@@ -48,26 +91,18 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
             weight_decay = 0.
     else:
         parameters = model.parameters()
 
     if 'fused' in opt_lower:
         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
 
-    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
-    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
-        opt_args['eps'] = args.opt_eps
-    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
-        opt_args['betas'] = args.opt_betas
-    if hasattr(args, 'opt_args') and args.opt_args is not None:
-        opt_args.update(args.opt_args)
+    opt_args = dict(lr=lr, weight_decay=weight_decay, **kwargs)
 
     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
     if opt_lower == 'sgd' or opt_lower == 'nesterov':
         opt_args.pop('eps', None)
-        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'momentum':
         opt_args.pop('eps', None)
-        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
     elif opt_lower == 'adam':
         optimizer = optim.Adam(parameters, **opt_args)
     elif opt_lower == 'adamw':
@@ -79,29 +114,29 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     elif opt_lower == 'adamp':
         optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
     elif opt_lower == 'sgdp':
-        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'adadelta':
         optimizer = optim.Adadelta(parameters, **opt_args)
     elif opt_lower == 'adafactor':
-        if not args.lr:
+        if not lr:
             opt_args['lr'] = None
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'adahessian':
         optimizer = Adahessian(parameters, **opt_args)
     elif opt_lower == 'rmsprop':
-        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
     elif opt_lower == 'rmsproptf':
-        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
     elif opt_lower == 'novograd':
         optimizer = NovoGrad(parameters, **opt_args)
     elif opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
     elif opt_lower == 'fusedsgd':
         opt_args.pop('eps', None)
-        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
         opt_args.pop('eps', None)
-        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
     elif opt_lower == 'fusedadam':
         optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
     elif opt_lower == 'fusedadamw':
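
# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch of the refactored factory API, not part
# of the diff above. The new create_optimizer() takes the model plus explicit
# keyword args, and optimizer_kwargs() bridges an argparse namespace (or any
# cfg-like object with opt/lr/weight_decay/momentum attributes) to those args.

import argparse
import torch.nn as nn
from timm.optim import create_optimizer, optimizer_kwargs

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# direct keyword-arg style
opt = create_optimizer(net, opt_name='adamw', lr=1e-3, weight_decay=0.05)

# argparse/cfg style, mirroring the train.py change below
args = argparse.Namespace(opt='sgd', lr=0.1, weight_decay=1e-4, momentum=0.9,
                          opt_eps=None, opt_betas=None, opt_args=None)
opt = create_optimizer(net, **optimizer_kwargs(cfg=args))
# ---------------------------------------------------------------------------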

@@ -33,7 +33,7 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_c
     convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
-from timm.optim import create_optimizer
+from timm.optim import create_optimizer, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -142,6 +142,8 @@ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                     help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 2)')
+parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
+                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
@@ -258,6 +260,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
+parser.add_argument('--experiment', default='', type=str, metavar='NAME',
+                    help='name of train experiment, name of sub-folder for output')
 parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                     help='Best metric (default: "top1"')
 parser.add_argument('--tta', type=int, default=0, metavar='N',
@@ -385,7 +389,7 @@ def main():
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
 
-    optimizer = create_optimizer(args, model)
+    optimizer = create_optimizer(model, **optimizer_kwargs(cfg=args))
 
     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
@@ -451,7 +455,9 @@ def main():
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+        args.dataset,
+        root=args.data_dir, split=args.train_split, is_training=True,
+        batch_size=args.batch_size, repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
         args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
@@ -541,13 +547,15 @@ def main():
     saver = None
     output_dir = ''
     if args.local_rank == 0:
-        output_base = args.output if args.output else './output'
-        exp_name = '-'.join([
-            datetime.now().strftime("%Y%m%d-%H%M%S"),
-            safe_model_name(args.model),
-            str(data_config['input_size'][-1])
-        ])
-        output_dir = get_outdir(output_base, 'train', exp_name)
+        if args.experiment:
+            exp_name = args.experiment
+        else:
+            exp_name = '-'.join([
+                datetime.now().strftime("%Y%m%d-%H%M%S"),
+                args.model,
+                str(data_config['input_size'][-1])
+            ])
+        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(
             model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
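
# ---------------------------------------------------------------------------
# Editor's note: a hypothetical invocation illustrating the new --experiment
# flag (not part of the diff above). With --experiment set, checkpoints land in
# a fixed sub-folder; otherwise the timestamped name is still generated:
#
#   ./train.py /path/to/imagenet --model vit_small_r26_s32_224 --experiment vit_hybrid_baseline
#   # -> output dir: ./output/train/vit_hybrid_baseline
#
#   ./train.py /path/to/imagenet --model vit_small_r26_s32_224
#   # -> output dir: ./output/train/<YYYYmmdd-HHMMSS>-vit_small_r26_s32_224-224
# ---------------------------------------------------------------------------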

@@ -152,7 +152,7 @@ def validate(args):
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))
 
-    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
+    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
     test_time_pool = False
     if not args.no_test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
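
# ---------------------------------------------------------------------------
# Editor's note: a minimal sketch of the call being changed above, not part of
# the diff. resolve_data_config() merges CLI overrides with the model's
# default_cfg; use_test_size=True prefers a test-time input size when the model
# defines one, and verbose=True logs the resolved values.

import timm
from timm.data import resolve_data_config

model = timm.create_model('resnet50', pretrained=False)
data_config = resolve_data_config({}, model=model, use_test_size=True, verbose=True)
# typical keys: 'input_size', 'interpolation', 'mean', 'std', 'crop_pct'
print(data_config)
# ---------------------------------------------------------------------------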
