Merge pull request #533 from rwightman/pit_and_vit_update
Addition of PiT models and update/cleanup of ViT, new NFNet weight, TFDS wrapper fix, few misc fixes/updates
commit d5ed58d623

benchmark.py
@@ -0,0 +1,481 @@
#!/usr/bin/env python3
""" Model Benchmark Script

An inference and train step benchmark script for timm models.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import json
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from functools import partial

from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
from timm.loss import LabelSmoothingCrossEntropy  # used for the --smoothing train path below
from timm.utils import AverageMeter, setup_default_logging


has_apex = False
try:
    from apex import amp
    has_apex = True
except ImportError:
    pass

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch Benchmark')

# benchmark specific args
parser.add_argument('--model-list', metavar='NAME', default='',
                    help='txt file based list of model names to benchmark')
parser.add_argument('--bench', default='both', type=str,
                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False,
                    help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--num-warm-iter', default=10, type=int,
                    metavar='N', help='Number of warmup iterations (default: 10)')
parser.add_argument('--num-bench-iter', default=40, type=int,
                    metavar='N', help='Number of benchmark iterations (default: 40)')

# common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
                    help='model architecture (default: resnet50)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number of classes in dataset')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--precision', default='float32', type=str,
                    help='Numeric precision. One of (amp, float32, float16, bfloat16)')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model to torchscript for inference')


# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')


# model regularization / loss params that impact model or loss fn
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')
def timestamp(sync=False):
    # sync is a no-op for CPU timing; the arg exists so this fn and cuda_timestamp share a signature
    return time.perf_counter()


def cuda_timestamp(sync=False, device=None):
    if sync:
        torch.cuda.synchronize(device=device)
    return time.perf_counter()


def count_params(model: nn.Module):
    return sum([m.numel() for m in model.parameters()])


def resolve_precision(precision: str):
    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
    use_amp = False
    model_dtype = torch.float32
    data_dtype = torch.float32
    if precision == 'amp':
        use_amp = True
    elif precision == 'float16':
        model_dtype = torch.float16
        data_dtype = torch.float16
    elif precision == 'bfloat16':
        model_dtype = torch.bfloat16
        data_dtype = torch.bfloat16
    return use_amp, model_dtype, data_dtype
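A quick sanity check of resolve_precision() above (an illustrative editorial aside, not part of the diff): 'amp' turns on autocast while tensors stay float32; the explicit dtype strings set both the model and input dtype.

# Illustrative sketch, not part of the PR: expected resolve_precision() outputs.
assert resolve_precision('amp') == (True, torch.float32, torch.float32)
assert resolve_precision('float16') == (False, torch.float16, torch.float16)
assert resolve_precision('bfloat16') == (False, torch.bfloat16, torch.bfloat16)
assert resolve_precision('float32') == (False, torch.float32, torch.float32)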
class BenchmarkRunner:
    def __init__(
            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
            num_warm_iter=10, num_bench_iter=50, **kwargs):
        self.model_name = model_name
        self.detail = detail
        self.device = device
        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
        self.channels_last = kwargs.pop('channels_last', False)
        self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress

        self.model = create_model(
            model_name,
            num_classes=kwargs.pop('num_classes', None),
            in_chans=3,
            global_pool=kwargs.pop('gp', 'fast'),
            scriptable=torchscript)
        self.model.to(
            device=self.device,
            dtype=self.model_dtype,
            memory_format=torch.channels_last if self.channels_last else None)
        self.num_classes = self.model.num_classes
        self.param_count = count_params(self.model)
        _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
        if torchscript:
            self.model = torch.jit.script(self.model)

        data_config = resolve_data_config(kwargs, model=self.model, use_test_size=True)
        self.input_size = data_config['input_size']
        self.batch_size = kwargs.pop('batch_size', 256)

        self.example_inputs = None
        self.num_warm_iter = num_warm_iter
        self.num_bench_iter = num_bench_iter
        self.log_freq = num_bench_iter // 5
        if 'cuda' in self.device:
            self.time_fn = partial(cuda_timestamp, device=self.device)
        else:
            self.time_fn = timestamp

    def _init_input(self):
        self.example_inputs = torch.randn(
            (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
        if self.channels_last:
            self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)
class InferenceBenchmarkRunner(BenchmarkRunner):

    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.eval()

    def run(self):
        def _step():
            t_step_start = self.time_fn()
            with self.amp_autocast():
                output = self.model(self.example_inputs)
            t_step_end = self.time_fn(True)
            return t_step_end - t_step_start

        _logger.info(
            f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        with torch.no_grad():
            self._init_input()

            for _ in range(self.num_warm_iter):
                _step()

            total_step = 0.
            num_samples = 0
            t_run_start = self.time_fn()
            for i in range(self.num_bench_iter):
                delta_fwd = _step()
                total_step += delta_fwd
                num_samples += self.batch_size
                num_steps = i + 1
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Infer [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_end = self.time_fn(True)
            t_run_elapsed = t_run_end - t_run_start

        results = dict(
            samples_per_sec=round(num_samples / t_run_elapsed, 2),
            step_time=round(1000 * total_step / self.num_bench_iter, 3),
            batch_size=self.batch_size,
            img_size=self.input_size[-1],
            param_count=round(self.param_count / 1e6, 2),
        )

        _logger.info(
            f"Inference benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

        return results
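As a usage sketch (an editorial aside, not part of the diff), the runner can be driven directly from Python; the model name and option values below are illustrative:

# Hypothetical usage sketch of InferenceBenchmarkRunner above; 'resnet50' and
# the option values are illustrative, not from the PR.
runner = InferenceBenchmarkRunner(
    model_name='resnet50', device='cuda', precision='amp',
    batch_size=64, num_warm_iter=10, num_bench_iter=40)
print(runner.run())  # dict w/ samples_per_sec, step_time, batch_size, img_size, param_count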
class TrainBenchmarkRunner(BenchmarkRunner):

    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.train()

        # NOTE the original diff assigned plain CrossEntropyLoss in both branches of this if/else;
        # actually using the smoothing value popped here appears to be the intent.
        smoothing = kwargs.pop('smoothing', 0)
        if smoothing > 0:
            self.loss = LabelSmoothingCrossEntropy(smoothing=smoothing).to(self.device)
        else:
            self.loss = nn.CrossEntropyLoss().to(self.device)
        self.target_shape = tuple()

        self.optimizer = create_optimizer_v2(
            self.model,
            optimizer_name=kwargs.pop('opt', 'sgd'),
            learning_rate=kwargs.pop('lr', 1e-4))

    def _gen_target(self, batch_size):
        return torch.empty(
            (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)

    def run(self):
        def _step(detail=False):
            self.optimizer.zero_grad()  # can this be ignored?
            t_start = self.time_fn()
            t_fwd_end = t_start
            t_bwd_end = t_start
            with self.amp_autocast():
                output = self.model(self.example_inputs)
            if isinstance(output, tuple):
                output = output[0]
            if detail:
                t_fwd_end = self.time_fn(True)
            target = self._gen_target(output.shape[0])
            self.loss(output, target).backward()
            if detail:
                t_bwd_end = self.time_fn(True)
            self.optimizer.step()
            t_end = self.time_fn(True)
            if detail:
                delta_fwd = t_fwd_end - t_start
                delta_bwd = t_bwd_end - t_fwd_end
                delta_opt = t_end - t_bwd_end
                return delta_fwd, delta_bwd, delta_opt
            else:
                delta_step = t_end - t_start
                return delta_step

        _logger.info(
            f'Running train benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        self._init_input()

        for _ in range(self.num_warm_iter):
            _step()

        t_run_start = self.time_fn()
        if self.detail:
            total_fwd = 0.
            total_bwd = 0.
            total_opt = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_fwd, delta_bwd, delta_opt = _step(True)
                num_samples += self.batch_size
                total_fwd += delta_fwd
                total_bwd += delta_bwd
                total_opt += delta_opt
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    total_step = total_fwd + total_bwd + total_opt
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd,"
                        f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd,"
                        f" {1000 * total_opt / num_steps:0.3f} ms/step opt."
                    )
            total_step = total_fwd + total_bwd + total_opt
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3),
                bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3),
                opt_time=round(1000 * total_opt / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )
        else:
            total_step = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_step = _step(False)
                num_samples += self.batch_size
                total_step += delta_step
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )

        _logger.info(
            f"Train benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

        return results
def decay_batch_exp(batch_size, factor=0.5, divisor=16):
    out_batch_size = batch_size * factor
    if out_batch_size > divisor:
        out_batch_size = (out_batch_size + 1) // divisor * divisor
    else:
        out_batch_size = batch_size - 1
    return max(0, int(out_batch_size))


def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
    batch_size = initial_batch_size
    results = dict()
    while batch_size >= 1:
        try:
            bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
            results = bench.run()
            return results
        except RuntimeError as e:
            torch.cuda.empty_cache()
            batch_size = decay_batch_exp(batch_size)
            print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
    return results
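To make the OOM retry path concrete (an illustrative editorial aside, not part of the diff): decay_batch_exp roughly halves the batch size, snapped to a multiple of the divisor, while above the divisor, then backs off one sample at a time.

# Illustrative trace of decay_batch_exp() above (not part of the PR):
bs, trace = 256, []
for _ in range(6):
    trace.append(bs)
    bs = decay_batch_exp(bs)
print(trace)  # [256, 128, 64, 32, 31, 30]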
def benchmark(args):
    if args.amp:
        _logger.warning("Overriding precision to 'amp' since --amp flag set.")
        args.precision = 'amp'
    _logger.info(f'Benchmarking in {args.precision} precision. '
                 f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                 f'torchscript {"enabled" if args.torchscript else "disabled"}')

    bench_kwargs = vars(args).copy()
    bench_kwargs.pop('amp')
    model = bench_kwargs.pop('model')
    batch_size = bench_kwargs.pop('batch_size')

    bench_fns = (InferenceBenchmarkRunner,)
    prefixes = ('infer',)
    if args.bench == 'both':
        bench_fns = (
            InferenceBenchmarkRunner,
            TrainBenchmarkRunner
        )
        prefixes = ('infer', 'train')
    elif args.bench == 'train':
        bench_fns = (TrainBenchmarkRunner,)
        prefixes = ('train',)

    model_results = OrderedDict(model=model)
    for prefix, bench_fn in zip(prefixes, bench_fns):
        run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
        if prefix:
            run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
        model_results.update(run_results)
    param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
    model_results.setdefault('param_count', param_count)
    model_results.pop('train_param_count', 0)
    return model_results
def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []

    if args.model_list:
        args.model = ''
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]
    elif args.model == 'all':
        # validate all models in a list of names with pretrained checkpoints
        args.pretrained = True
        model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model)
        model_cfgs = [(n, None) for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './benchmark.csv'
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            for m, _ in model_cfgs:
                if not m:
                    continue
                args.model = m
                r = benchmark(args)
                results.append(r)
        except KeyboardInterrupt:
            pass
        sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec'
        results = sorted(results, key=lambda x: x[sort_key], reverse=True)
        if len(results):
            write_results(results_file, results)

        json_str = json.dumps(results, indent=4)
        print(json_str)
    else:
        benchmark(args)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
        cf.flush()


if __name__ == '__main__':
    main()
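A hedged sketch of driving the script through its own parser (editorial, not part of the diff); the flags come from the parser above, the model name is illustrative:

# Editorial usage sketch, not part of the diff. Equivalent CLI:
#   python benchmark.py --model resnet50 --bench both --amp
args = parser.parse_args(['--model', 'resnet50', '--bench', 'both', '--amp'])
print(json.dumps(benchmark(args), indent=4))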
timm/models/pit.py
@@ -0,0 +1,388 @@
""" Pooling-based Vision Transformer (PiT) in PyTorch

A PyTorch implementation of Pooling-based Vision Transformers as described in
'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302

This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below.

Modifications for timm by / Copyright 2020 Ross Wightman
"""
# PiT
# Copyright 2021-present NAVER Corp.
# Apache License v2.0

import math
import re
from copy import deepcopy
from functools import partial
from typing import Tuple

import torch
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import trunc_normal_, to_2tuple
from .registry import register_model
from .vision_transformer import Block
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # PiT models, weights converted from the official impl at https://github.com/naver-ai/pit
    'pit_ti_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_730.pth'),
    'pit_xs_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'),
    'pit_s_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'),
    'pit_b_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
    'pit_ti_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
        classifier=('head', 'head_dist')),
    'pit_xs_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
        classifier=('head', 'head_dist')),
    'pit_s_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
        classifier=('head', 'head_dist')),
    'pit_b_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
        classifier=('head', 'head_dist')),
}
class SequentialTuple(nn.Sequential):
    """ This module exists to work around torchscript typing issues list -> list"""
    def __init__(self, *args):
        super(SequentialTuple, self).__init__(*args)

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        for module in self:
            x = module(x)
        return x
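Why SequentialTuple exists (an illustrative editorial aside, not from the PR): under torchscript, nn.Sequential.forward is typed for a single Tensor, so passing a (Tensor, Tensor) pair through a scripted plain Sequential fails type checking, while the Tuple-typed override above scripts cleanly. The _PassThrough stage below is hypothetical:

# Editorial sketch, not part of the diff: a Tuple-typed stage scripts fine
# inside SequentialTuple, where plain nn.Sequential would fail type checking.
class _PassThrough(nn.Module):
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        return x

scripted = torch.jit.script(SequentialTuple(_PassThrough()))
t = torch.zeros(1, 3, 4, 4)
out = scripted((t, t))  # tuple in, tuple out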
class Transformer(nn.Module):
    def __init__(
            self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
        super(Transformer, self).__init__()
        embed_dim = base_dim * heads

        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim,
                num_heads=heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_prob[i],
                norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
            for i in range(depth)])

        self.pool = pool

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        x, cls_tokens = x
        B, C, H, W = x.shape
        token_length = cls_tokens.shape[1]

        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_tokens, x), dim=1)

        x = self.blocks(x)

        cls_tokens = x[:, :token_length]
        x = x[:, token_length:]
        x = x.transpose(1, 2).reshape(B, C, H, W)

        if self.pool is not None:
            x, cls_tokens = self.pool(x, cls_tokens)
        return x, cls_tokens
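A shape walk-through of one Transformer stage above (illustrative numbers, an editorial aside, not from the PR): spatial tokens enter NCHW, are flattened to a sequence, joined with the class token(s), run through the ViT blocks, then folded back before the optional pooling layer:

# Editorial sketch, not part of the diff: tensor shapes through one stage.
stage = Transformer(base_dim=32, depth=2, heads=2, mlp_ratio=4, drop_path_prob=[0.0, 0.0])
feat = torch.randn(1, 64, 27, 27)   # embed_dim = base_dim * heads = 64
cls = torch.randn(1, 1, 64)
feat, cls = stage((feat, cls))
print(feat.shape, cls.shape)        # [1, 64, 27, 27] and [1, 1, 64]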
class ConvHeadPooling(nn.Module):
    def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'):
        super(ConvHeadPooling, self).__init__()

        self.conv = nn.Conv2d(
            in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride,
            padding_mode=padding_mode, groups=in_feature)
        self.fc = nn.Linear(in_feature, out_feature)

    def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.conv(x)
        cls_token = self.fc(cls_token)
        return x, cls_token
class ConvEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
        super(ConvEmbedding, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True)

    def forward(self, x):
        x = self.conv(x)
        return x
class PoolingVisionTransformer(nn.Module):
    """ Pooling-based Vision Transformer

    A PyTorch implementation of 'Rethinking Spatial Dimensions of Vision Transformers'
    - https://arxiv.org/abs/2103.16302
    """
    def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
                 mlp_ratio, num_classes=1000, in_chans=3, distilled=False,
                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
        super(PoolingVisionTransformer, self).__init__()

        padding = 0
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1)
        width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1)

        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes
        self.num_tokens = 2 if distilled else 1

        self.patch_size = patch_size
        self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width))
        self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)

        self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)

        transformers = []
        # stochastic depth decay rule
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)]
        for stage in range(len(depth)):
            pool = None
            if stage < len(heads) - 1:
                pool = ConvHeadPooling(
                    base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2)
            transformers += [Transformer(
                base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool,
                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage])
            ]
        self.transformers = SequentialTuple(*transformers)
        self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
        self.embed_dim = base_dims[-1] * heads[-1]

        # Classifier head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
            if num_classes > 0 and distilled else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
            if num_classes > 0 and self.num_tokens == 2 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x + self.pos_embed)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x, cls_tokens = self.transformers((x, cls_tokens))
        cls_tokens = self.norm(cls_tokens)
        return cls_tokens

    def forward(self, x):
        x = self.forward_features(x)
        x_cls = self.head(x[:, 0])
        if self.num_tokens > 1:
            x_dist = self.head_dist(x[:, 1])
            if self.training and not torch.jit.is_scripting():
                return x_cls, x_dist
            else:
                return (x_cls + x_dist) / 2
        else:
            return x_cls
def checkpoint_filter_fn(state_dict, model):
    """ preprocess checkpoints """
    out_dict = {}
    p_blocks = re.compile(r'pools\.(\d)\.')
    for k, v in state_dict.items():
        # FIXME need to update resize for PiT impl
        # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
        #     # To resize pos embedding when using model at different size from pretrained weights
        #     v = resize_pos_embed(v, model.pos_embed)
        k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k)
        out_dict[k] = v
    return out_dict
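An illustrative check of the key remap above (an editorial aside, not from the PR): the official checkpoints keep pooling layers under 'pools.N.', which this model exposes as 'transformers.N.pool.':

# Editorial sketch, not part of the diff: the regex remap in action.
sd = {'pools.0.conv.weight': torch.zeros(1)}
print(list(checkpoint_filter_fn(sd, model=None)))  # ['transformers.0.pool.conv.weight']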
def _create_pit(variant, pretrained=False, **kwargs):
    default_cfg = deepcopy(default_cfgs[variant])
    overlay_external_default_cfg(default_cfg, kwargs)
    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-2:]
    img_size = kwargs.pop('img_size', default_img_size)
    num_classes = kwargs.pop('num_classes', default_num_classes)

    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    model = build_model_with_cfg(
        PoolingVisionTransformer, variant, pretrained,
        default_cfg=default_cfg,
        img_size=img_size,
        num_classes=num_classes,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)

    return model
@register_model
def pit_b_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=14,
        stride=7,
        base_dims=[64, 64, 64],
        depth=[3, 6, 4],
        heads=[4, 8, 16],
        mlp_ratio=4,
        **kwargs
    )
    return _create_pit('pit_b_224', pretrained, **model_kwargs)


@register_model
def pit_s_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16,
        stride=8,
        base_dims=[48, 48, 48],
        depth=[2, 6, 4],
        heads=[3, 6, 12],
        mlp_ratio=4,
        **kwargs
    )
    return _create_pit('pit_s_224', pretrained, **model_kwargs)


@register_model
def pit_xs_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16,
        stride=8,
        base_dims=[48, 48, 48],
        depth=[2, 6, 4],
        heads=[2, 4, 8],
        mlp_ratio=4,
        **kwargs
    )
    return _create_pit('pit_xs_224', pretrained, **model_kwargs)


@register_model
def pit_ti_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16,
        stride=8,
        base_dims=[32, 32, 32],
        depth=[2, 6, 4],
        heads=[2, 4, 8],
        mlp_ratio=4,
        **kwargs
    )
    return _create_pit('pit_ti_224', pretrained, **model_kwargs)


@register_model
def pit_b_distilled_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=14,
        stride=7,
        base_dims=[64, 64, 64],
        depth=[3, 6, 4],
        heads=[4, 8, 16],
        mlp_ratio=4,
        distilled=True,
        **kwargs
    )
    return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs)


@register_model
def pit_s_distilled_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16,
        stride=8,
        base_dims=[48, 48, 48],
        depth=[2, 6, 4],
        heads=[3, 6, 12],
        mlp_ratio=4,
        distilled=True,
        **kwargs
    )
    return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs)


@register_model
def pit_xs_distilled_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16,
        stride=8,
        base_dims=[48, 48, 48],
        depth=[2, 6, 4],
        heads=[2, 4, 8],
        mlp_ratio=4,
        distilled=True,
        **kwargs
    )
    return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs)


@register_model
def pit_ti_distilled_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16,
        stride=8,
        base_dims=[32, 32, 32],
        depth=[2, 6, 4],
        heads=[2, 4, 8],
        mlp_ratio=4,
        distilled=True,
        **kwargs
    )
    return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs)
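Usage sketch (editorial, not part of the diff): any of the variants registered above can be created through the timm factory.

# Editorial sketch, not part of the diff: create a registered PiT variant.
import timm
model = timm.create_model('pit_s_224', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])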
timm/models/vision_transformer_hybrid.py
@@ -0,0 +1,313 @@
""" Hybrid Vision Transformer (ViT) in PyTorch

A PyTorch implementation of the Hybrid Vision Transformers as described in
'An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale'
- https://arxiv.org/abs/2010.11929

NOTE This relies on code in vision_transformer.py. The hybrid model definitions were moved here to
keep file sizes sane.

Hacked together by / Copyright 2020 Ross Wightman
"""
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import StdConv2dSame, StdConv2d, to_2tuple
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .registry import register_model
from timm.models.vision_transformer import _create_vision_transformer
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
    'vit_base_r50_s16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        num_classes=21843, crop_pct=0.9),

    # hybrid in-1k models (weights ported from official JAX impl)
    'vit_base_r50_s16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
        input_size=(3, 384, 384), crop_pct=1.0),

    # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
    'vit_tiny_r_s16_p8_224': _cfg(),
    'vit_small_r_s16_p8_224': _cfg(),
    'vit_small_r20_s16_p2_224': _cfg(),
    'vit_small_r20_s16_224': _cfg(),
    'vit_small_r26_s32_224': _cfg(),
    'vit_base_r20_s16_224': _cfg(),
    'vit_base_r26_s32_224': _cfg(),
    'vit_base_r50_s16_224': _cfg(),
    'vit_large_r50_s32_224': _cfg(),

    # hybrid models (using timm resnet backbones)
    'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
}
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
        self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
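A minimal sketch of the shape probing above (editorial, not from the PR; the tiny conv stem is hypothetical): with feature_size unset, the backbone runs once on zeros to discover its output grid, which fixes num_patches.

# Editorial sketch, not part of the diff.
stem = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=16))
embed = HybridEmbed(stem, img_size=224, patch_size=1, embed_dim=96)
tokens = embed(torch.randn(1, 3, 224, 224))
print(embed.num_patches, tokens.shape)  # 196 torch.Size([1, 196, 96])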
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
    default_cfg = deepcopy(default_cfgs[variant])
    embed_layer = partial(HybridEmbed, backbone=backbone)
    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
    return _create_vision_transformer(
        variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs)
def _resnetv2(layers=(3, 4, 9), **kwargs):
    """ ResNet-V2 backbone helper"""
    padding_same = kwargs.get('padding_same', True)
    if padding_same:
        stem_type = 'same'
        conv_layer = StdConv2dSame
    else:
        stem_type = ''
        conv_layer = StdConv2d
    if len(layers):
        backbone = ResNetV2(
            layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
            preact=False, stem_type=stem_type, conv_layer=conv_layer)
    else:
        backbone = create_resnetv2_stem(
            kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer)
    return backbone
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
    # NOTE this is forwarding to model def above for backwards compatibility
    return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)


@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    backbone = _resnetv2((3, 4, 9), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
    # NOTE this is forwarding to model def above for backwards compatibility
    return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)


@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
    """
    backbone = _resnetv2(layers=(), **kwargs)
    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
    """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
    """
    backbone = _resnetv2(layers=(), **kwargs)
    model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
    """ R20+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
    """
    backbone = _resnetv2((2, 4), **kwargs)
    model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r20_s16_224(pretrained=False, **kwargs):
    """ R20+ViT-S/S16 hybrid.
    """
    backbone = _resnetv2((2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r26_s32_224(pretrained=False, **kwargs):
    """ R26+ViT-S/S32 hybrid.
    """
    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r20_s16_224(pretrained=False, **kwargs):
    """ R20+ViT-B/S16 hybrid.
    """
    backbone = _resnetv2((2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r26_s32_224(pretrained=False, **kwargs):
    """ R26+ViT-B/S32 hybrid.
    """
    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r50_s16_224(pretrained=False, **kwargs):
    """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
    """
    backbone = _resnetv2((3, 4, 9), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model
@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
    """ R50+ViT-L/S32 hybrid.
    """
    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    # NOTE the diff used ViT-Base dims here (embed_dim=768, depth=12, num_heads=12), which looks
    # like a copy/paste slip for a ViT-Large config; standard ViT-L dims substituted below.
    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
    """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
    """
    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
    """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
    """
    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs):
    """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
    """
    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs):
    """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
    """
    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model
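Usage sketch (editorial, not part of the diff): the hybrids registered above are created through the same timm factory.

# Editorial sketch, not part of the diff: instantiate a hybrid ViT from above.
import timm
model = timm.create_model('vit_base_r50_s16_384', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 384, 384))
print(out.shape)  # torch.Size([1, 1000])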
timm/version.py
@@ -1 +1 @@
-__version__ = '0.4.6'
+__version__ = '0.4.7'