Compare commits

...

3 Commits

--- a/benchmark.py
+++ b/benchmark.py
@@ -22,7 +22,7 @@ from timm.data import resolve_data_config
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs

 has_apex = False
 try:
@@ -108,12 +108,15 @@ parser.add_argument('--grad-checkpointing', action='store_true', default=False,
                     help='Enable gradient checkpointing through model blocks/stages')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
+parser.add_argument('--amp-dtype', default='float16', type=str,
+                    help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
 parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
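Note: the new `--model-kwargs` flag relies on the `ParseKwargs` action imported from `timm.utils` above. As a rough sketch of the idea (a hypothetical stand-in, not the exact timm implementation), such an argparse action folds `key=value` tokens into a dict, parsing values as Python literals where possible:

    import argparse
    import ast


    class ParseKwargsSketch(argparse.Action):
        """Hypothetical stand-in for timm.utils.ParseKwargs."""
        def __call__(self, parser, namespace, values, option_string=None):
            kwargs = {}
            for token in values:
                key, _, value = token.partition('=')
                try:
                    kwargs[key] = ast.literal_eval(value)  # numbers, tuples, bools, ...
                except (ValueError, SyntaxError):
                    kwargs[key] = value  # plain strings, e.g. act_layer=silu
            setattr(namespace, self.dest, kwargs)


    parser = argparse.ArgumentParser()
    parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargsSketch)
    args = parser.parse_args(['--model-kwargs', 'depths=(3,3,9,3)', 'act_layer=silu'])
    assert args.model_kwargs == {'depths': (3, 3, 9, 3), 'act_layer': 'silu'}

The parsed dict is later splatted into `create_model(...)`, so arbitrary architecture overrides can be passed from the command line.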
@@ -124,7 +127,6 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd optimization.")
-
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd"')
@@ -168,19 +170,21 @@ def count_params(model: nn.Module):

 def resolve_precision(precision: str):
-    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
-    use_amp = False
+    assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
+    amp_dtype = None  # amp disabled
     model_dtype = torch.float32
     data_dtype = torch.float32
     if precision == 'amp':
-        use_amp = True
+        amp_dtype = torch.float16
+    elif precision == 'amp_bfloat16':
+        amp_dtype = torch.bfloat16
     elif precision == 'float16':
         model_dtype = torch.float16
         data_dtype = torch.float16
     elif precision == 'bfloat16':
         model_dtype = torch.bfloat16
         data_dtype = torch.bfloat16
-    return use_amp, model_dtype, data_dtype
+    return amp_dtype, model_dtype, data_dtype


 def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
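With this change the first return value is an AMP autocast dtype (or `None`) instead of a boolean, so callers can distinguish float16 from bfloat16 autocast. A quick sanity check of the new contract (a sketch, assuming the `resolve_precision` defined above):

    import torch

    assert resolve_precision('amp') == (torch.float16, torch.float32, torch.float32)
    assert resolve_precision('amp_bfloat16') == (torch.bfloat16, torch.float32, torch.float32)
    assert resolve_precision('bfloat16') == (None, torch.bfloat16, torch.bfloat16)
    assert resolve_precision('float32') == (None, torch.float32, torch.float32)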
@@ -228,9 +232,12 @@ class BenchmarkRunner:
         self.model_name = model_name
         self.detail = detail
         self.device = device
-        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
+        self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
-        self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
+        if self.amp_dtype is not None:
+            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+        else:
+            self.amp_autocast = suppress

         if fuser:
             set_jit_fuser(fuser)
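`suppress` (from `contextlib`) is a no-op context manager, so downstream code can always write `with self.amp_autocast():` whether or not AMP is active. A minimal standalone sketch of the pattern:

    from contextlib import suppress
    from functools import partial

    import torch

    amp_dtype = torch.bfloat16  # e.g. first value returned by resolve_precision('amp_bfloat16')
    if amp_dtype is not None and torch.cuda.is_available():
        amp_autocast = partial(torch.cuda.amp.autocast, dtype=amp_dtype)
    else:
        amp_autocast = suppress

    with amp_autocast():
        pass  # forward pass runs under autocast when enabled, unchanged otherwise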
@@ -243,6 +250,7 @@ class BenchmarkRunner:
             drop_rate=kwargs.pop('drop', 0.),
             drop_path_rate=kwargs.pop('drop_path', None),
             drop_block_rate=kwargs.pop('drop_block', None),
+            **kwargs.pop('model_kwargs', {}),
         )
         self.model.to(
             device=self.device,
@@ -560,7 +568,7 @@ def _try_run(
 def benchmark(args):
     if args.amp:
         _logger.warning("Overriding precision to 'amp' since --amp flag set.")
-        args.precision = 'amp'
+        args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
     _logger.info(f'Benchmarking in {args.precision} precision. '
                  f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                  f'torchscript {"enabled" if args.torchscript else "disabled"}')

--- a/inference.py
+++ b/inference.py
@@ -20,7 +20,7 @@ import torch
 from timm.data import create_dataset, create_loader, resolve_data_config
 from timm.layers import apply_test_time_pool
 from timm.models import create_model
-from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
+from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs

 try:
     from apex import amp
@@ -72,6 +72,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--in-chans', type=int, default=None, metavar='N',
+                    help='Image input channels (default: None => 3)')
 parser.add_argument('--input-size', default=None, nargs=3, type=int,
                     metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 parser.add_argument('--use-train-size', action='store_true', default=False,
@@ -110,6 +112,7 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -170,12 +173,19 @@ def main():
         set_jit_fuser(args.fuser)

     # create model
+    in_chans = 3
+    if args.in_chans is not None:
+        in_chans = args.in_chans
+    elif args.input_size is not None:
+        in_chans = args.input_size[0]
+
     model = create_model(
         args.model,
         num_classes=args.num_classes,
-        in_chans=3,
+        in_chans=in_chans,
         pretrained=args.pretrained,
         checkpoint_path=args.checkpoint,
+        **args.model_kwargs,
     )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
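The resolution order is: an explicit `--in-chans` wins, then the first element of `--input-size`, then the default of 3. A small sketch of the same precedence (hypothetical helper, mirroring the logic above):

    def resolve_in_chans(in_chans_arg=None, input_size_arg=None):
        if in_chans_arg is not None:
            return in_chans_arg
        if input_size_arg is not None:
            return input_size_arg[0]
        return 3

    assert resolve_in_chans() == 3
    assert resolve_in_chans(input_size_arg=(1, 224, 224)) == 1   # --input-size 1 224 224
    assert resolve_in_chans(4, (1, 224, 224)) == 4               # --in-chans 4 wins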

--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -209,6 +209,7 @@ def push_to_hf_hub(
         private: bool = False,
         create_pr: bool = False,
         model_config: Optional[dict] = None,
+        model_card: Optional[dict] = None,
 ):
     # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
@@ -232,9 +233,23 @@ def push_to_hf_hub(
         # Add readme if it does not exist
         if not has_readme:
+            model_card = model_card or {}
             model_name = repo_id.split('/')[-1]
             readme_path = Path(tmpdir) / "README.md"
-            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
+            readme_text = "---\n"
+            readme_text += "tags:\n- image-classification\n- timm\n"
+            readme_text += "library_tag: timm\n"
+            readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
+            readme_text += "---\n"
+            readme_text += f"# Model card for {model_name}\n"
+            if 'description' in model_card:
+                readme_text += f"\n{model_card['description']}\n"
+            if 'details' in model_card:
+                readme_text += f"\n## Model Details\n"
+                for k, v in model_card['details'].items():
+                    readme_text += f"- **{k}:** {v}\n"
+            if 'citation' in model_card:
+                readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n"
             readme_path.write_text(readme_text)

     # Upload model and return
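To illustrate the new `model_card` argument, a call along these lines (all values hypothetical) would render license, description, details, and citation sections into the generated README:

    from timm.models import push_to_hf_hub  # import location assumed for this timm version

    push_to_hf_hub(
        model,
        'myuser/convnext_tiny.custom',  # hypothetical repo_id
        model_config={'num_classes': 1000},
        model_card={
            'license': 'apache-2.0',
            'description': 'ConvNeXt-T fine-tuned on a custom dataset.',
            'details': {'Params (M)': 28.6, 'Image size': 224},
            'citation': '@article{liu2022convnet, ...}\n',
        },
    )

The generated README only includes the sections present in the dict; an empty `model_card` falls back to the minimal header plus a default apache-2.0 license line.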

--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):

 def interleave_blocks(
-        types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
+        types: Tuple[str, str], d,
+        every: Union[int, List[int]] = 1,
+        first: bool = False,
+        **kwargs,
 ) -> Tuple[ByoBlockCfg]:
     """ interleave 2 block types in stack
     """
@@ -1587,15 +1590,32 @@ class ByobNet(nn.Module):
             in_chans=3,
             global_pool='avg',
             output_stride=32,
-            zero_init_last=True,
             img_size=None,
             drop_rate=0.,
             drop_path_rate=0.,
+            zero_init_last=True,
+            **kwargs,
     ):
+        """
+        Args:
+            cfg (ByoModelCfg): Model architecture configuration
+            num_classes (int): Number of classifier classes (default: 1000)
+            in_chans (int): Number of input channels (default: 3)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            img_size (Union[int, Tuple[int]]): Image size for fixed image size models (i.e. self-attn)
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
+        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
         layers = get_layer_fns(cfg)

         if cfg.fixed_input_size:
             assert img_size is not None, 'img_size argument is required for fixed input size model'
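`replace` here is `dataclasses.replace`: it returns a copy of the cfg dataclass with the given fields overridden, which is how `--model-kwargs` style arguments reach these config-driven architectures. A self-contained sketch:

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class DemoCfg:  # stand-in for ByoModelCfg / CspModelCfg / NfCfg / RegNetCfg
        act_layer: str = 'relu'
        zero_init_last: bool = True

    cfg = DemoCfg()
    cfg = replace(cfg, act_layer='silu')  # overlay kwargs onto cfg
    assert cfg.act_layer == 'silu' and cfg.zero_init_last is True
    # unknown keys fail loudly instead of being silently dropped:
    # replace(cfg, not_a_field=1) raises TypeError

The same `cfg = replace(cfg, **kwargs)` line appears in CspNet, NormFreeNet, and RegNet further below.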

--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -167,7 +167,7 @@ class ConvNeXtStage(nn.Module):
                 conv_bias=conv_bias,
                 use_grn=use_grn,
                 act_layer=act_layer,
-                norm_layer=norm_layer if conv_mlp else norm_layer_cl
+                norm_layer=norm_layer if conv_mlp else norm_layer_cl,
             ))
             in_chs = out_chs
         self.blocks = nn.Sequential(*stage_blocks)
@@ -184,16 +184,6 @@ class ConvNeXtStage(nn.Module):
 class ConvNeXt(nn.Module):
     r""" ConvNeXt
     A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
-
-    Args:
-        in_chans (int): Number of input image channels. Default: 3
-        num_classes (int): Number of classes for classification head. Default: 1000
-        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
-        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
-        drop_rate (float): Head dropout rate
-        drop_path_rate (float): Stochastic depth rate. Default: 0.
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
-        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
     """

     def __init__(
@@ -218,6 +208,28 @@ class ConvNeXt(nn.Module):
             drop_rate=0.,
             drop_path_rate=0.,
     ):
+        """
+        Args:
+            in_chans (int): Number of input image channels (default: 3)
+            num_classes (int): Number of classes for classification head (default: 1000)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
+            dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
+            kernel_sizes (Union[int, List[int]]): Depthwise convolution kernel-sizes for each stage (default: 7)
+            ls_init_value (float): Init value for Layer Scale (default: 1e-6)
+            stem_type (str): Type of stem (default: 'patch')
+            patch_size (int): Stem patch size for patch stem (default: 4)
+            head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
+            head_norm_first (bool): Apply normalization before global pool + head (default: False)
+            conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
+            conv_bias (bool): Use bias layers w/ all convolutions (default: True)
+            use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
+            act_layer (Union[str, nn.Module]): Activation Layer
+            norm_layer (Union[str, nn.Module]): Normalization Layer
+            drop_rate (float): Head dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth rate (default: 0.)
+        """
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
@@ -279,7 +291,7 @@ class ConvNeXt(nn.Module):
                 use_grn=use_grn,
                 act_layer=act_layer,
                 norm_layer=norm_layer,
-                norm_layer_cl=norm_layer_cl
+                norm_layer_cl=norm_layer_cl,
             ))
             prev_chs = out_chs
         # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2

--- a/timm/models/cspnet.py
+++ b/timm/models/cspnet.py
@@ -12,7 +12,7 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from dataclasses import dataclass, asdict
+from dataclasses import dataclass, asdict, replace
 from functools import partial
 from typing import Any, Dict, Optional, Tuple, Union
@@ -518,7 +518,7 @@ class CrossStage(nn.Module):
             cross_linear=False,
             block_dpr=None,
             block_fn=BottleneckBlock,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(CrossStage, self).__init__()
         first_dilation = first_dilation or dilation
@@ -558,7 +558,7 @@ class CrossStage(nn.Module):
                 bottle_ratio=bottle_ratio,
                 groups=groups,
                 drop_path=block_dpr[i] if block_dpr is not None else 0.,
-                **block_kwargs
+                **block_kwargs,
             ))
             prev_chs = block_out_chs
@@ -597,7 +597,7 @@ class CrossStage3(nn.Module):
             cross_linear=False,
             block_dpr=None,
             block_fn=BottleneckBlock,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(CrossStage3, self).__init__()
         first_dilation = first_dilation or dilation
@@ -635,7 +635,7 @@ class CrossStage3(nn.Module):
                 bottle_ratio=bottle_ratio,
                 groups=groups,
                 drop_path=block_dpr[i] if block_dpr is not None else 0.,
-                **block_kwargs
+                **block_kwargs,
             ))
             prev_chs = block_out_chs
@@ -668,7 +668,7 @@ class DarkStage(nn.Module):
             avg_down=False,
             block_fn=BottleneckBlock,
             block_dpr=None,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(DarkStage, self).__init__()
         first_dilation = first_dilation or dilation
@@ -715,7 +715,7 @@ def create_csp_stem(
         padding='',
         act_layer=nn.ReLU,
         norm_layer=nn.BatchNorm2d,
-        aa_layer=None
+        aa_layer=None,
 ):
     stem = nn.Sequential()
     feature_info = []
@@ -738,7 +738,7 @@ def create_csp_stem(
             stride=conv_stride,
             padding=padding if i == 0 else '',
             act_layer=act_layer,
-            norm_layer=norm_layer
+            norm_layer=norm_layer,
         ))
         stem_stride *= conv_stride
         prev_chs = chs
@@ -800,7 +800,7 @@ def create_csp_stages(
         cfg: CspModelCfg,
         drop_path_rate: float,
         output_stride: int,
-        stem_feat: Dict[str, Any]
+        stem_feat: Dict[str, Any],
 ):
     cfg_dict = asdict(cfg.stages)
     num_stages = len(cfg.stages.depth)
@@ -868,12 +868,27 @@ class CspNet(nn.Module):
             global_pool='avg',
             drop_rate=0.,
             drop_path_rate=0.,
-            zero_init_last=True
+            zero_init_last=True,
+            **kwargs,
     ):
+        """
+        Args:
+            cfg (CspModelCfg): Model architecture configuration
+            in_chans (int): Number of input channels (default: 3)
+            num_classes (int): Number of classifier classes (default: 1000)
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            global_pool (str): Global pooling type (default: 'avg')
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         assert output_stride in (8, 16, 32)
+        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
         layer_args = dict(
             act_layer=cfg.act_layer,
             norm_layer=cfg.norm_layer,

--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -17,7 +17,7 @@ Status:
 Hacked together by / copyright Ross Wightman, 2021.
 """
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Tuple, Optional
@@ -159,11 +159,25 @@ class NfCfg:

 def _nfres_cfg(
-        depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
+        depths,
+        channels=(256, 512, 1024, 2048),
+        group_size=None,
+        act_layer='relu',
+        attn_layer=None,
+        attn_kwargs=None,
+):
     attn_kwargs = attn_kwargs or {}
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
-        group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+        depths=depths,
+        channels=channels,
+        stem_type='7x7_pool',
+        stem_chs=64,
+        bottle_ratio=0.25,
+        group_size=group_size,
+        act_layer=act_layer,
+        attn_layer=attn_layer,
+        attn_kwargs=attn_kwargs,
+    )
     return cfg
@@ -171,28 +185,70 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
     num_features = 1280 * channels[-1] // 440
     attn_kwargs = dict(rd_ratio=0.5)
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
-        num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
+        depths=depths,
+        channels=channels,
+        stem_type='3x3',
+        group_size=8,
+        width_factor=0.75,
+        bottle_ratio=2.25,
+        num_features=num_features,
+        reg=True,
+        attn_layer='se',
+        attn_kwargs=attn_kwargs,
+    )
     return cfg


 def _nfnet_cfg(
-        depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
-        act_layer='gelu', attn_layer='se', attn_kwargs=None):
+        depths,
+        channels=(256, 512, 1536, 1536),
+        group_size=128,
+        bottle_ratio=0.5,
+        feat_mult=2.,
+        act_layer='gelu',
+        attn_layer='se',
+        attn_kwargs=None,
+):
     num_features = int(channels[-1] * feat_mult)
     attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
-        bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
-        attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+        depths=depths,
+        channels=channels,
+        stem_type='deep_quad',
+        stem_chs=128,
+        group_size=group_size,
+        bottle_ratio=bottle_ratio,
+        extra_conv=True,
+        num_features=num_features,
+        act_layer=act_layer,
+        attn_layer=attn_layer,
+        attn_kwargs=attn_kwargs,
+    )
     return cfg


-def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
+def _dm_nfnet_cfg(
+        depths,
+        channels=(256, 512, 1536, 1536),
+        act_layer='gelu',
+        skipinit=True,
+):
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
-        bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
-        num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
+        depths=depths,
+        channels=channels,
+        stem_type='deep_quad',
+        stem_chs=128,
+        group_size=128,
+        bottle_ratio=0.5,
+        extra_conv=True,
+        gamma_in_act=True,
+        same_padding=True,
+        skipinit=skipinit,
+        num_features=int(channels[-1] * 2.0),
+        act_layer=act_layer,
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.5),
+    )
     return cfg
@@ -278,7 +334,14 @@ def act_with_gamma(act_type, gamma: float = 1.):

 class DownsampleAvg(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            conv_layer=ScaledStdConv2d,
+    ):
         """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
         super(DownsampleAvg, self).__init__()
         avg_stride = stride if dilation == 1 else 1
@@ -299,9 +362,26 @@ class NormFreeBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
-            alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
-            skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs=None,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            alpha=1.0,
+            beta=1.0,
+            bottle_ratio=0.25,
+            group_size=None,
+            ch_div=1,
+            reg=True,
+            extra_conv=False,
+            skipinit=False,
+            attn_layer=None,
+            attn_gain=2.0,
+            act_layer=None,
+            conv_layer=None,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         first_dilation = first_dilation or dilation
         out_chs = out_chs or in_chs
@@ -316,7 +396,13 @@ class NormFreeBlock(nn.Module):
         if in_chs != out_chs or stride != 1 or dilation != first_dilation:
             self.downsample = DownsampleAvg(
-                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                conv_layer=conv_layer,
+            )
         else:
             self.downsample = None
@@ -452,14 +538,33 @@ class NormFreeNet(nn.Module):
         for what it is/does. Approx 8-10% throughput loss.
     """
     def __init__(
-            self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-            drop_rate=0., drop_path_rate=0.
+            self,
+            cfg: NfCfg,
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            output_stride=32,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            **kwargs,
     ):
+        """
+        Args:
+            cfg (NfCfg): Model architecture configuration
+            num_classes (int): Number of classifier classes (default: 1000)
+            in_chans (int): Number of input channels (default: 3)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
+        cfg = replace(cfg, **kwargs)
         assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
         conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
         if cfg.gamma_in_act:
@@ -472,7 +577,12 @@ class NormFreeNet(nn.Module):
         stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
         self.stem, stem_stride, stem_feat = create_stem(
-            in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
+            in_chans,
+            stem_chs,
+            cfg.stem_type,
+            conv_layer=conv_layer,
+            act_layer=act_layer,
+        )

         self.feature_info = [stem_feat]
         drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]

--- a/timm/models/regnet.py
+++ b/timm/models/regnet.py
@@ -14,7 +14,7 @@ Weights from original impl have been modified
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Optional, Union, Callable
@@ -237,7 +237,15 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la

 def create_shortcut(
-        downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
+        downsample_type,
+        in_chs,
+        out_chs,
+        kernel_size,
+        stride,
+        dilation=(1, 1),
+        norm_layer=None,
+        preact=False,
+):
     assert downsample_type in ('avg', 'conv1x1', '', None)
     if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
         dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
@@ -259,9 +267,21 @@ class Bottleneck(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
-            downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1,
+            group_size=1,
+            se_ratio=0.25,
+            downsample='conv1x1',
+            linear_out=False,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(Bottleneck, self).__init__()
         act_layer = get_act_layer(act_layer)
         bottleneck_chs = int(round(out_chs * bottle_ratio))
@@ -307,9 +327,21 @@ class PreBottleneck(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
-            downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1,
+            group_size=1,
+            se_ratio=0.25,
+            downsample='conv1x1',
+            linear_out=False,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(PreBottleneck, self).__init__()
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         bottleneck_chs = int(round(out_chs * bottle_ratio))
@@ -353,8 +385,16 @@ class RegStage(nn.Module):
     """Stage (sequence of blocks w/ the same output shape)."""

     def __init__(
-            self, depth, in_chs, out_chs, stride, dilation,
-            drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
+            self,
+            depth,
+            in_chs,
+            out_chs,
+            stride,
+            dilation,
+            drop_path_rates=None,
+            block_fn=Bottleneck,
+            **block_kwargs,
+    ):
         super(RegStage, self).__init__()
         self.grad_checkpointing = False
@@ -367,8 +407,13 @@ class RegStage(nn.Module):
             name = "b{}".format(i + 1)
             self.add_module(
                 name, block_fn(
-                    block_in_chs, out_chs, stride=block_stride, dilation=block_dilation,
-                    drop_path_rate=dpr, **block_kwargs)
+                    block_in_chs,
+                    out_chs,
+                    stride=block_stride,
+                    dilation=block_dilation,
+                    drop_path_rate=dpr,
+                    **block_kwargs,
+                )
             )
             first_dilation = dilation
@@ -389,12 +434,35 @@ class RegNet(nn.Module):
     """

     def __init__(
-            self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg',
-            drop_rate=0., drop_path_rate=0., zero_init_last=True):
+            self,
+            cfg: RegNetCfg,
+            in_chans=3,
+            num_classes=1000,
+            output_stride=32,
+            global_pool='avg',
+            drop_rate=0.,
+            drop_path_rate=0.,
+            zero_init_last=True,
+            **kwargs,
+    ):
+        """
+        Args:
+            cfg (RegNetCfg): Model architecture configuration
+            in_chans (int): Number of input channels (default: 3)
+            num_classes (int): Number of classifier classes (default: 1000)
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            global_pool (str): Global pooling type (default: 'avg')
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         assert output_stride in (8, 16, 32)
+        cfg = replace(cfg, **kwargs)  # update cfg with extra passed kwargs

         # Construct the stem
         stem_width = cfg.stem_width
@@ -461,8 +529,12 @@ class RegNet(nn.Module):
             dict(zip(arg_names, params)) for params in
             zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)]
         common_args = dict(
-            downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
-            act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
+            downsample=cfg.downsample,
+            se_ratio=cfg.se_ratio,
+            linear_out=cfg.linear_out,
+            act_layer=cfg.act_layer,
+            norm_layer=cfg.norm_layer,
+        )
         return per_stage_args, common_args

     @torch.jit.ignore
@@ -518,7 +590,6 @@ def _init_weights(module, name='', zero_init_last=False):

 def _filter_fn(state_dict):
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if 'classy_state_dict' in state_dict:
         import re
         state_dict = state_dict['classy_state_dict']['base_model']['model']

--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -16,7 +16,7 @@ import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
-    create_classifier
+    get_act_layer, get_norm_layer, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, model_entrypoint
@@ -500,7 +500,14 @@ class Bottleneck(nn.Module):

 def downsample_conv(
-        in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        dilation=1,
+        first_dilation=None,
+        norm_layer=None,
+):
     norm_layer = norm_layer or nn.BatchNorm2d
     kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
     first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
@@ -514,7 +521,14 @@ def downsample_conv(

 def downsample_avg(
-        in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        dilation=1,
+        first_dilation=None,
+        norm_layer=None,
+):
     norm_layer = norm_layer or nn.BatchNorm2d
     avg_stride = stride if dilation == 1 else 1
     if stride == 1 and dilation == 1:
@@ -627,31 +641,6 @@ class ResNet(nn.Module):
     SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
         reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block

-    Parameters
-    ----------
-    block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
-    layers : list of int, number of layers in each block
-    num_classes : int, default 1000, number of classification classes.
-    in_chans : int, default 3, number of input (color) channels.
-    output_stride : int, default 32, output stride of the network, 32, 16, or 8.
-    global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
-    cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
-    base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
-    stem_width : int, default 64, number of channels in stem convolutions
-    stem_type : str, default ''
-        The type of stem:
-          * '', default - a single 7x7 conv with a width of stem_width
-          * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
-          * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
-    block_reduce_first : int, default 1
-        Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
-    down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
-    avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
-    act_layer : nn.Module, activation layer
-    norm_layer : nn.Module, normalization layer
-    aa_layer : nn.Module, anti-aliasing layer
-    drop_rate : float, default 0. Dropout probability before classifier, for training
     """

     def __init__(
@@ -679,6 +668,36 @@ class ResNet(nn.Module):
             zero_init_last=True,
             block_args=None,
     ):
+        """
+        Args:
+            block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
+            layers (List[int]) : number of layers in each block
+            num_classes (int): number of classification classes (default 1000)
+            in_chans (int): number of input (color) channels. (default 3)
+            output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
+            global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
+            cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
+            base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
+            stem_width (int): number of channels in stem convolutions (default 64)
+            stem_type (str): The type of stem (default ''):
+                * '', default - a single 7x7 conv with a width of stem_width
+                * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
+                * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
+            replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
+            block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
+                1 for all archs except senets, where 2 (default 1)
+            down_kernel_size (int): kernel size of residual block downsample path,
+                1x1 for most, 3x3 for senets (default: 1)
+            avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
+            act_layer (str, nn.Module): activation layer
+            norm_layer (str, nn.Module): normalization layer
+            aa_layer (nn.Module): anti-aliasing layer
+            drop_rate (float): Dropout probability before classifier, for training (default 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default 0.)
+            drop_block_rate (float): Drop block rate (default 0.)
+            zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
+            block_args (dict): Extra kwargs to pass through to block module
+        """
         super(ResNet, self).__init__()
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
@@ -686,6 +705,9 @@ class ResNet(nn.Module):
         self.drop_rate = drop_rate
         self.grad_checkpointing = False

+        act_layer = get_act_layer(act_layer)
+        norm_layer = get_norm_layer(norm_layer)
+
         # Stem
         deep_stem = 'deep' in stem_type
         inplanes = stem_width * 2 if deep_stem else 64
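`get_act_layer` / `get_norm_layer` accept either a string name or an `nn.Module` class and return a layer class, so `ResNet` can now be built with string arguments such as `act_layer='silu'` or `norm_layer='groupnorm'` (e.g. via `--model-kwargs`). Roughly, on recent PyTorch/timm versions:

    import torch.nn as nn
    from timm.layers import get_act_layer

    act_layer = get_act_layer('silu')   # string name -> activation class
    assert act_layer is nn.SiLU
    act_layer = get_act_layer(nn.GELU)  # classes pass through unchanged
    assert act_layer is nn.GELU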

--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -37,7 +37,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
-    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
+    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
 from ._registry import register_model
@@ -276,8 +276,16 @@ class Bottleneck(nn.Module):

 class DownsampleConv(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
-            conv_layer=None, norm_layer=None):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            preact=True,
+            conv_layer=None,
+            norm_layer=None,
+    ):
         super(DownsampleConv, self).__init__()
         self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
         self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
@@ -288,8 +296,16 @@ class DownsampleConv(nn.Module):

 class DownsampleAvg(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
-            preact=True, conv_layer=None, norm_layer=None):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            preact=True,
+            conv_layer=None,
+            norm_layer=None,
+    ):
         """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
         super(DownsampleAvg, self).__init__()
         avg_stride = stride if dilation == 1 else 1
@@ -334,9 +350,18 @@ class ResNetStage(nn.Module):
             drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
             stride = stride if block_idx == 0 else 1
             self.blocks.add_module(str(block_idx), block_fn(
-                prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
-                first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
-                **layer_kwargs, **block_kwargs))
+                prev_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                bottle_ratio=bottle_ratio,
+                groups=groups,
+                first_dilation=first_dilation,
+                proj_layer=proj_layer,
+                drop_path_rate=drop_path_rate,
+                **layer_kwargs,
+                **block_kwargs,
+            ))
             prev_chs = out_chs
             first_dilation = dilation
             proj_layer = None
@@ -413,21 +438,49 @@ class ResNetV2(nn.Module):
             avg_down=False,
             preact=True,
             act_layer=nn.ReLU,
-            conv_layer=StdConv2d,
             norm_layer=partial(GroupNormAct, num_groups=32),
+            conv_layer=StdConv2d,
             drop_rate=0.,
             drop_path_rate=0.,
             zero_init_last=False,
     ):
+        """
+        Args:
+            layers (List[int]) : number of layers in each block
+            channels (List[int]) : number of channels in each block
+            num_classes (int): number of classification classes (default 1000)
+            in_chans (int): number of input (color) channels. (default 3)
+            global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
+            output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
+            width_factor (int): channel (width) multiplication factor
+            stem_chs (int): stem width (default: 64)
+            stem_type (str): stem type (default: '' == 7x7)
+            avg_down (bool): average pooling in residual downsampling (default: False)
+            preact (bool): pre-activation (default: True)
+            act_layer (Union[str, nn.Module]): activation layer
+            norm_layer (Union[str, nn.Module]): normalization layer
+            conv_layer (nn.Module): convolution module
+            drop_rate: classifier dropout rate (default: 0.)
+            drop_path_rate: stochastic depth rate (default: 0.)
+            zero_init_last: zero-init last weight in residual path (default: False)
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         wf = width_factor
+        norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
+        act_layer = get_act_layer(act_layer)

         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
         self.stem = create_resnetv2_stem(
-            in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
+            in_chans,
+            stem_chs,
+            stem_type,
+            preact,
+            conv_layer=conv_layer,
+            norm_layer=norm_layer,
+        )
         stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
         self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
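`get_norm_act_layer` resolves a plain norm layer plus activation into a fused norm+act class (e.g. `BatchNormAct2d`, `GroupNormAct`), which is what the blocks here expect. A hedged sketch of the intent:

    import torch
    import torch.nn as nn
    from timm.layers import get_norm_act_layer

    norm_act = get_norm_act_layer(nn.BatchNorm2d, act_layer=nn.ReLU)
    layer = norm_act(64)  # behaves like BatchNorm2d(64) followed by ReLU
    out = layer(torch.randn(2, 64, 8, 8))
    assert out.shape == (2, 64, 8, 8) and bool((out >= 0).all())  # ReLU applied after norm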

--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -697,6 +697,13 @@ def _cfg(url='', **kwargs):

 default_cfgs = generate_default_cfgs({
+
+    # re-finetuned augreg 21k FT on in1k weights
+    'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
+    'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+
     # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
     'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
@@ -751,13 +758,6 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),

-    # re-finetuned augreg 21k FT on in1k weights
-    'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/'),
-    'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
-    'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/'),
-
     # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
     'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
@@ -802,7 +802,6 @@ default_cfgs = generate_default_cfgs({
     'vit_giant_patch14_224.untrained': _cfg(url=''),
     'vit_gigantic_patch14_224.untrained': _cfg(url=''),

-    # patch models, imagenet21k (weights from official Google JAX impl)
     'vit_large_patch32_224.orig_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
@@ -869,7 +868,6 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

-    # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
@@ -880,7 +878,7 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),

-    # custom timm variants
+    # Custom timm variants
     'vit_base_patch16_rpn_224.in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
         hf_hub_id='timm/'),
@@ -896,52 +894,6 @@ default_cfgs = generate_default_cfgs({
     'vit_base_patch16_gap_224': _cfg(),

     # CLIP pretrained image tower and related fine-tuned weights
-    'vit_base_patch32_clip_224.laion2b': _cfg(
-        hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
-    'vit_base_patch16_clip_224.laion2b': _cfg(
-        #hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
-    'vit_large_patch14_clip_224.laion2b': _cfg(
-        hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
-    'vit_huge_patch14_clip_224.laion2b': _cfg(
-        hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-    'vit_giant_patch14_clip_224.laion2b': _cfg(
-        hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-
-    'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
-    'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
-    'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
-    'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
-    'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
-        crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
-    'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
-    'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
-        hf_hub_id='',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
     'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
@@ -973,28 +925,52 @@ default_cfgs = generate_default_cfgs({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),

-    'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg(
-        #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
-    'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
-    'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
-        hf_hub_id='timm/',
-        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
-    'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
-
-    'vit_base_patch32_clip_224.openai': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
-    'vit_base_patch16_clip_224.openai': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
-    'vit_large_patch14_clip_224.openai': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
+        # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
+    'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
+    'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
+    'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
+    'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
+    'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
+
+    'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
+    'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
+    'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
+    'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
+    'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+        crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
+    'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
+    'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
+        hf_hub_id='',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),

     'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
         hf_hub_id='timm/',
@ -1010,30 +986,21 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
-    'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
-        #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
-    'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
-    'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
-    'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
-    'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
-    'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
-        hf_hub_id='timm/',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
+    'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg(
+        #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
+    'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
+    'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
+    'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
     'vit_base_patch32_clip_224.openai_ft_in12k': _cfg(
-        #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k',
+        # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
     'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
         hf_hub_id='timm/',
@ -1042,6 +1009,37 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
+    'vit_base_patch32_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_base_patch16_clip_224.laion2b': _cfg(
+        # hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
+    'vit_huge_patch14_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_giant_patch14_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_base_patch32_clip_224.openai': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_base_patch16_clip_224.openai': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_large_patch14_clip_224.openai': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     # experimental (may be removed)
     'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
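
The new `.laion2b` cfgs above are the first in this file to pull weights straight from the upstream open_clip repositories (an `hf_hub_id` pointing at a `laion/...` repo plus an explicit `hf_hub_filename`), rather than from timm-hosted copies. A minimal usage sketch, assuming a timm build new enough to resolve dotted pretrained-tag names like these:

import timm

# Downloads open_clip_pytorch_model.bin from the laion/ hub repo named in
# the cfg above and loads it into the timm ViT-B/32 CLIP image tower.
model = timm.create_model('vit_base_patch32_clip_224.laion2b', pretrained=True)
model.eval()
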
@ -1152,8 +1150,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
 def vit_tiny_patch16_224(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
-    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
+    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
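
The change repeated through the hunks below moves `**kwargs` out of the `model_kwargs` defaults and into a `dict(model_kwargs, **kwargs)` merge at the call site, so caller-supplied kwargs now override a variant's defaults instead of colliding with them. A small sketch of the two behaviours:

# New style: later keys win, so a caller's depth=6 overrides the default.
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
kwargs = dict(depth=6)  # e.g. passed through create_model(..., depth=6)
merged = dict(model_kwargs, **kwargs)
assert merged['depth'] == 6 and merged['patch_size'] == 16

# Old style: dict(depth=12, ..., **kwargs) with kwargs containing 'depth'
# raises TypeError: got multiple values for keyword argument 'depth'.
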
@ -1161,8 +1159,8 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs):
 def vit_tiny_patch16_384(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16) @ 384x384.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
-    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
+    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1170,8 +1168,8 @@ def vit_tiny_patch16_384(pretrained=False, **kwargs):
 def vit_small_patch32_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/32)
     """
-    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1179,8 +1177,8 @@ def vit_small_patch32_224(pretrained=False, **kwargs):
 def vit_small_patch32_384(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/32) at 384x384.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1188,8 +1186,8 @@ def vit_small_patch32_384(pretrained=False, **kwargs):
 def vit_small_patch16_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1197,8 +1195,8 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 def vit_small_patch16_384(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1206,8 +1204,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs):
 def vit_small_patch8_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/8)
     """
-    model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1216,8 +1214,8 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1226,8 +1224,8 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1236,8 +1234,8 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1246,8 +1244,8 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1256,8 +1254,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1265,8 +1263,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
 def vit_large_patch32_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1275,8 +1273,8 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1285,8 +1283,8 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1295,8 +1293,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
""" """
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1304,8 +1302,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_large_patch14_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/14)
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1313,8 +1311,8 @@ def vit_large_patch14_224(pretrained=False, **kwargs):
 def vit_huge_patch14_224(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
+    model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1322,8 +1320,8 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
 def vit_giant_patch14_224(pretrained=False, **kwargs):
     """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16)
+    model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1331,8 +1329,9 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
 def vit_gigantic_patch14_224(pretrained=False, **kwargs):
     """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
+    model = _create_vision_transformer(
+        'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1341,8 +1340,9 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1352,8 +1352,9 @@ def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1363,8 +1364,9 @@ def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1374,8 +1376,9 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1384,9 +1387,9 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256 """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs) model = _create_vision_transformer(
model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs) 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1395,8 +1398,9 @@ def vit_base_patch32_clip_224(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 224x224 """ ViT-B/32 CLIP image tower @ 224x224
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer('vit_base_patch32_clip_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1405,8 +1409,9 @@ def vit_base_patch32_clip_384(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 384x384 """ ViT-B/32 CLIP image tower @ 384x384
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer('vit_base_patch32_clip_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1415,8 +1420,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 448x448 """ ViT-B/32 CLIP image tower @ 448x448
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer('vit_base_patch32_clip_448', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1424,9 +1430,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
 def vit_base_patch16_clip_224(pretrained=False, **kwargs):
     """ ViT-B/16 CLIP image tower
     """
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_clip_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1434,9 +1440,9 @@ def vit_base_patch16_clip_224(pretrained=False, **kwargs):
 def vit_base_patch16_clip_384(pretrained=False, **kwargs):
     """ ViT-B/16 CLIP image tower @ 384x384
     """
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_clip_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1444,9 +1450,9 @@ def vit_base_patch16_clip_384(pretrained=False, **kwargs):
 def vit_large_patch14_clip_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/14) CLIP image tower
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_large_patch14_clip_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1454,9 +1460,9 @@ def vit_large_patch14_clip_224(pretrained=False, **kwargs):
 def vit_large_patch14_clip_336(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_large_patch14_clip_336', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1464,9 +1470,9 @@ def vit_large_patch14_clip_336(pretrained=False, **kwargs):
 def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) CLIP image tower.
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_huge_patch14_clip_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1474,9 +1480,9 @@ def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
 def vit_huge_patch14_clip_336(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_huge_patch14_clip_336', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1486,9 +1492,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
     Pretrained weights from CLIP image tower.
     """
     model_kwargs = dict(
-        patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
-        pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_giant_patch14_clip_224', pretrained=pretrained, **model_kwargs)
+        patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1498,8 +1504,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
 def vit_base_patch32_plus_256(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32+)
     """
-    model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
+    model = _create_vision_transformer(
+        'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1507,8 +1514,9 @@ def vit_base_patch32_plus_256(pretrained=False, **kwargs):
 def vit_base_patch16_plus_240(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16+)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
+    model = _create_vision_transformer(
+        'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1517,9 +1525,10 @@ def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ residual post-norm """ ViT-Base (ViT-B/16) w/ residual post-norm
""" """
model_kwargs = dict( model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) class_token=False, block_fn=ResPostBlock, global_pool='avg')
model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer(
'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1529,8 +1538,9 @@ def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
     Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
     Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5)
+    model = _create_vision_transformer(
+        'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1541,8 +1551,9 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
     Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
+    model = _create_vision_transformer(
+        'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1551,27 +1562,26 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) model = _create_vision_transformer(
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@register_model @register_model
def eva_large_patch14_196(pretrained=False, **kwargs): def eva_large_patch14_196(pretrained=False, **kwargs):
""" EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain"""
model_kwargs = dict( model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) model = _create_vision_transformer(
model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) 'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@register_model @register_model
def eva_large_patch14_336(pretrained=False, **kwargs): def eva_large_patch14_336(pretrained=False, **kwargs):
""" EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
model_kwargs = dict( model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs)
return model return model
@ -1579,8 +1589,8 @@ def eva_large_patch14_336(pretrained=False, **kwargs):
 def flexivit_small(pretrained=False, **kwargs):
     """ FlexiViT-Small
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs)
-    model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
+    model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1588,8 +1598,8 @@ def flexivit_small(pretrained=False, **kwargs):
 def flexivit_base(pretrained=False, **kwargs):
     """ FlexiViT-Base
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs)
-    model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
+    model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@ -1597,6 +1607,6 @@ def flexivit_base(pretrained=False, **kwargs):
 def flexivit_large(pretrained=False, **kwargs):
     """ FlexiViT-Large
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs)
-    model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
+    model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@ -181,8 +181,18 @@ class SequentialAppendList(nn.Sequential):
 class OsaBlock(nn.Module):

     def __init__(
-            self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
-            depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
+            self,
+            in_chs,
+            mid_chs,
+            out_chs,
+            layer_per_block,
+            residual=False,
+            depthwise=False,
+            attn='',
+            norm_layer=BatchNormAct2d,
+            act_layer=nn.ReLU,
+            drop_path=None,
+    ):
         super(OsaBlock, self).__init__()

         self.residual = residual
@ -232,9 +242,20 @@ class OsaBlock(nn.Module):
 class OsaStage(nn.Module):

     def __init__(
-            self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
-            residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
-            drop_path_rates=None):
+            self,
+            in_chs,
+            mid_chs,
+            out_chs,
+            block_per_stage,
+            layer_per_block,
+            downsample=True,
+            residual=True,
+            depthwise=False,
+            attn='ese',
+            norm_layer=BatchNormAct2d,
+            act_layer=nn.ReLU,
+            drop_path_rates=None,
+    ):
         super(OsaStage, self).__init__()

         self.grad_checkpointing = False
@ -270,16 +291,38 @@ class OsaStage(nn.Module):
 class VovNet(nn.Module):

     def __init__(
-            self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
-            output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
-        """ VovNet (v2)
+            self,
+            cfg,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            output_stride=32,
+            norm_layer=BatchNormAct2d,
+            act_layer=nn.ReLU,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            **kwargs,
+    ):
+        """
+        Args:
+            cfg (dict): Model architecture configuration
+            in_chans (int): Number of input channels (default: 3)
+            num_classes (int): Number of classifier classes (default: 1000)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            norm_layer (Union[str, nn.Module]): normalization layer
+            act_layer (Union[str, nn.Module]): activation layer
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            kwargs (dict): Extra kwargs overlayed onto cfg
         """
         super(VovNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        assert stem_stride in (4, 2)
         assert output_stride == 32  # FIXME support dilation
+        cfg = dict(cfg, **kwargs)
+        stem_stride = cfg.get("stem_stride", 4)
         stem_chs = cfg["stem_chs"]
         stage_conv_chs = cfg["stage_conv_chs"]
         stage_out_chs = cfg["stage_out_chs"]
@ -307,9 +350,15 @@ class VovNet(nn.Module):
         for i in range(4):  # num_stages
             downsample = stem_stride == 2 or i > 0  # first stage has no stride/downsample if stem_stride is 4
             stages += [OsaStage(
-                in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
-                downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
-            ]
+                in_ch_list[i],
+                stage_conv_chs[i],
+                stage_out_chs[i],
+                block_per_stage[i],
+                layer_per_block,
+                downsample=downsample,
+                drop_path_rates=stage_dpr[i],
+                **stage_args,
+            )]
             self.num_features = stage_out_chs[i]
             current_stride *= 2 if downsample else 1
             self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
@ -324,7 +373,6 @@ class VovNet(nn.Module):
             elif isinstance(m, nn.Linear):
                 nn.init.zeros_(m.bias)
-
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         return dict(

@ -8,7 +8,7 @@ from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
 from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
-from .misc import natural_key, add_bool_arg
+from .misc import natural_key, add_bool_arg, ParseKwargs
 from .model import unwrap_model, get_state_dict, freeze, unfreeze
 from .model_ema import ModelEma, ModelEmaV2
 from .random import random_seed

@ -2,6 +2,8 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import argparse
+import ast
 import re
@ -16,3 +18,15 @@ def add_bool_arg(parser, name, default=False, help=''):
     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
     parser.set_defaults(**{dest_name: default})
+
+
+class ParseKwargs(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        kw = {}
+        for value in values:
+            key, value = value.split('=')
+            try:
+                kw[key] = ast.literal_eval(value)
+            except ValueError:
+                kw[key] = str(value)  # fallback to string (avoid need to escape on command line)
+        setattr(namespace, self.dest, kw)
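
`ParseKwargs` collects space-separated `key=value` tokens into a dict, running each value through `ast.literal_eval` so numbers, bools, and tuples become real Python values while bare words fall back to plain strings. A quick check of that behaviour (flag name borrowed from the train.py hunk below, and assuming the `ParseKwargs` class above is importable):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

args = parser.parse_args('--model-kwargs drop_rate=0.1 act_layer=silu'.split())
# literal_eval turns '0.1' into a float; 'silu' is not a Python literal,
# so the ValueError fallback keeps it as the string 'silu'.
assert args.model_kwargs == {'drop_rate': 0.1, 'act_layer': 'silu'}
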

@ -89,56 +89,58 @@ parser.add_argument('--data-dir', metavar='DIR',
 parser.add_argument('--dataset', metavar='NAME', default='',
                     help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
 group.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
 group.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')
 group.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
 group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
 # Model parameters
 group = parser.add_argument_group('Model parameters')
 group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet50")')
 group.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
 group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
 group.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
 group.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
 group.add_argument('--num-classes', type=int, default=None, metavar='N',
                    help='number of label classes (Model default if None)')
 group.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
 group.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Image size (default: None => model default)')
 group.add_argument('--in-chans', type=int, default=None, metavar='N',
                    help='Image input channels (default: None => 3)')
 group.add_argument('--input-size', default=None, nargs=3, type=int,
-                   metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
+                   metavar='N N N',
+                   help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 group.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop percent (for validation only)')
 group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
 group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
 group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
 group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='Input batch size for training (default: 128)')
 group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
 group.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
+group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)

 scripting_group = group.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
@ -151,199 +153,200 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
 group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
-                   help='Optimizer (default: "sgd"')
+                   help='Optimizer (default: "sgd")')
 group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
 group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
 group.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
 group.add_argument('--weight-decay', type=float, default=2e-5,
                    help='weight decay (default: 2e-5)')
 group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
 group.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')
 group.add_argument('--layer-decay', type=float, default=None,
                    help='layer-wise learning rate decay (default: None)')
+group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
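
With `--model-kwargs` and `--opt-kwargs` both wired to `utils.ParseKwargs`, one-off model or optimizer options can flow from the command line without a dedicated flag for each. A hedged sketch of how train.py can forward them (abbreviated call sites, not the verbatim script):

# e.g. train.py --model vit_tiny_patch16_224 --model-kwargs drop_path_rate=0.1 \
#               --opt adamw --opt-kwargs eps=1e-6
model = create_model(
    args.model,
    pretrained=args.pretrained,
    **args.model_kwargs,  # extra kwargs merged into the model constructor
)
optimizer = create_optimizer_v2(
    model,
    opt=args.opt,
    lr=args.lr,
    weight_decay=args.weight_decay,
    **args.opt_kwargs,  # extra kwargs forwarded to the optimizer
)
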
# Learning rate schedule parameters # Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters') group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER', group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
help='LR scheduler (default: "step"') help='LR scheduler (default: "step"')
group.add_argument('--sched-on-updates', action='store_true', default=False, group.add_argument('--sched-on-updates', action='store_true', default=False,
help='Apply LR scheduler step on update instead of epoch end.') help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR', group.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate, overrides lr-base if set (default: None)') help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR', group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / base_size') help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV', group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
help='base learning rate batch size (divisor, default: 256).') help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE', group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
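In other words, when `--lr` is left unset the script derives it from `--lr-base`; a sketch of the implied rule (assuming "sqrt" scales by the square root of the batch-size ratio):

from math import sqrt

def resolve_lr(lr_base, global_batch_size, base_size=256, scale='linear'):
    # lr = lr_base * global_batch_size / base_size, linearly or by sqrt
    ratio = global_batch_size / base_size
    return lr_base * (sqrt(ratio) if scale == 'sqrt' else ratio)

# lr_base=0.1 at global batch 1024: 0.4 with linear scaling, 0.2 with sqrt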
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages') help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)') help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)') help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)') help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)') help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1') help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0, group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)') help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR', group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
help='warmup learning rate (default: 1e-5)') help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR', group.add_argument('--min-lr', type=float, default=0, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N', group.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)') help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N', group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)') help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES", group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing') help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N', group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
help='epoch interval to decay LR') help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports') help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False, group.add_argument('--warmup-prefix', action='store_true', default=False,
help='Exclude warmup period from decay schedule.') help='Exclude warmup period from decay schedule.')
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends') help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N', group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)') help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)') help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters # Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters') group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False, group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args') help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)') help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)') help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5, group.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability') help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0., group.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability') help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)') help='Color jitter factor (default: 0.4)')
group.add_argument('--aa', type=str, default=None, metavar='NAME', group.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)') help='Use AutoAugment policy. "v0" or "original". (default: None)')
group.add_argument('--aug-repeats', type=float, default=0, group.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)') help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0, group.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)') help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False, group.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False, group.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.') help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None, group.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)') help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0., metavar='PCT', group.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)') help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel', group.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")') help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1, group.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)') help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False, group.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split') help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0, group.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)') help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0, group.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0, group.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled') help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5, group.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled') help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch', group.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)') help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
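These flags feed `timm.data.Mixup`; a hedged usage sketch mapping flag to constructor argument (names per the timm API at the time of this change):

from timm.data import Mixup

mixup_fn = Mixup(
    mixup_alpha=0.2,       # --mixup
    cutmix_alpha=1.0,      # --cutmix
    prob=1.0,              # --mixup-prob
    switch_prob=0.5,       # --mixup-switch-prob
    mode='batch',          # --mixup-mode
    label_smoothing=0.1,   # --smoothing
    num_classes=1000,
)
# in the train loop: inputs, targets = mixup_fn(inputs, targets)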
group.add_argument('--smoothing', type=float, default=0.1, group.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)') help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random', group.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic; default: "random")') help='Training interpolation (random, bilinear, bicubic; default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT', group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)') help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)') help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT', group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)') help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT', group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)') help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently) # Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None, group.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)') help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None, group.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)') help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true', group.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce', group.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true', group.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.') help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average # Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters') group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False, group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights') help='Enable tracking moving average of model weights')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False, group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998, group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)') help='decay factor for model weights moving average (default: 0.9998)')
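For reference, the average these flags control is the usual exponential moving average; a minimal sketch of the per-parameter update (timm's ModelEmaV2 applies the same rule, and also tracks buffers):

import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.9998):  # --model-ema-decay
    # ema_w <- decay * ema_w + (1 - decay) * w
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)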
# Misc # Misc
group = parser.add_argument_group('Miscellaneous parameters') group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S', group.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)') help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all', group.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)') help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N', group.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status') help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N', group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint') help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
help='number of checkpoints to keep (default: 10)') help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N', group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)') help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False, group.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging') help='save images of input batches every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False, group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training') help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str, group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)') help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str, group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)') help='AMP impl to use, "native" or "apex" (default: native)')
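`--amp --amp-dtype bfloat16 --amp-impl native` corresponds to PyTorch-native autocast; a sketch, assuming `model`, `criterion`, `input`, and `target` from the surrounding train step:

import torch

amp_dtype = torch.bfloat16  # --amp-dtype bfloat16
with torch.autocast(device_type='cuda', dtype=amp_dtype):
    output = model(input)
    loss = criterion(output, target)
# torch.cuda.amp.GradScaler is typically paired with float16;
# bfloat16 usually trains without loss scaling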
group.add_argument('--no-ddp-bb', action='store_true', default=False, group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.') help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False, group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for (sometimes) more efficient transfer to GPU.') help='Pin CPU memory in DataLoader for (sometimes) more efficient transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False, group.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher') help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH', group.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)') help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME', group.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output') help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1")') help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N', group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int) group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch') help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False, group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb') help='log training and validation metrics to wandb')
def _parse_args(): def _parse_args():
@ -371,8 +374,6 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
if args.data and not args.data_dir:
args.data_dir = args.data
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
device = utils.init_distributed_device(args) device = utils.init_distributed_device(args)
if args.distributed: if args.distributed:
@ -383,14 +384,6 @@ def main():
_logger.info(f'Training with a single process on 1 device ({args.device}).') _logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0 assert args.rank >= 0
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability # resolve AMP arguments based on PyTorch / Apex availability
use_amp = None use_amp = None
amp_dtype = torch.float16 amp_dtype = torch.float16
@ -432,6 +425,7 @@ def main():
bn_eps=args.bn_eps, bn_eps=args.bn_eps,
scriptable=args.torchscript, scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint, checkpoint_path=args.initial_checkpoint,
**args.model_kwargs,
) )
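With the splat above, arbitrary constructor arguments reach the model factory straight from the CLI; an illustrative equivalence (which kwargs are valid depends on the chosen model):

from timm.models import create_model

# --model vit_base_patch16_224 --model-kwargs drop_rate=0.1
model = create_model('vit_base_patch16_224', drop_rate=0.1)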
if args.num_classes is None: if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -504,7 +498,11 @@ def main():
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) optimizer = create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
**args.opt_kwargs,
)
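One caveat worth noting: because both dicts are splatted into the same call, a `--opt-kwargs` key that collides with one produced by `optimizer_kwargs(cfg=args)` raises a TypeError rather than overriding it, so `--opt-kwargs` is for extra optimizer arguments not covered by the standard flags. A toy illustration:

def create(**kwargs):
    return kwargs

base = dict(opt='adamw', lr=1e-3)     # from optimizer_kwargs(cfg=args)
extra = dict(foreach=True)            # e.g. --opt-kwargs foreach=True
create(**base, **extra)               # fine: keys are disjoint
# create(**base, **dict(lr=1e-2))     # TypeError: multiple values for 'lr'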
# setup automatic mixed-precision (AMP) loss scaling and op casting # setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing amp_autocast = suppress # do nothing
@ -559,6 +557,8 @@ def main():
# NOTE: EMA model does not need to be wrapped by DDP # NOTE: EMA model does not need to be wrapped by DDP
# create the train and eval datasets # create the train and eval datasets
if args.data and not args.data_dir:
args.data_dir = args.data
dataset_train = create_dataset( dataset_train = create_dataset(
args.dataset, args.dataset,
root=args.data_dir, root=args.data_dir,
@ -712,6 +712,14 @@ def main():
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text) f.write(args_text)
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# setup learning rate schedule and starting epoch # setup learning rate schedule and starting epoch
updates_per_epoch = len(loader_train) updates_per_epoch = len(loader_train)
lr_scheduler, num_epochs = create_scheduler_v2( lr_scheduler, num_epochs = create_scheduler_v2(

@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
from timm.layers import apply_test_time_pool, set_fast_norm from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
decay_batch_step, check_batch_size_retry decay_batch_step, check_batch_size_retry, ParseKwargs
try: try:
from apex import amp from apex import amp
@ -71,6 +71,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)') metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int, parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty') metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int, parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False, parser.add_argument('--use-train-size', action='store_true', default=False,
@ -123,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true', parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm') help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group() scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true', scripting_group.add_argument('--torchscript', default=False, action='store_true',
@ -181,13 +185,20 @@ def validate(args):
set_fast_norm() set_fast_norm()
# create model # create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model( model = create_model(
args.model, args.model,
pretrained=args.pretrained, pretrained=args.pretrained,
num_classes=args.num_classes, num_classes=args.num_classes,
in_chans=3, in_chans=in_chans,
global_pool=args.gp, global_pool=args.gp,
scriptable=args.torchscript, scriptable=args.torchscript,
**args.model_kwargs,
) )
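The new resolution order is: explicit `--in-chans`, else the channel dim of `--input-size`, else 3. A self-contained mirror of that logic:

def resolve_in_chans(in_chans_arg, input_size_arg):
    if in_chans_arg is not None:
        return in_chans_arg
    if input_size_arg is not None:
        return input_size_arg[0]  # (chans, height, width)
    return 3

assert resolve_in_chans(None, (1, 224, 224)) == 1  # --input-size 1 224 224
assert resolve_in_chans(4, None) == 4              # --in-chans 4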
if args.num_classes is None: if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -232,8 +243,9 @@ def validate(args):
criterion = nn.CrossEntropyLoss().to(device) criterion = nn.CrossEntropyLoss().to(device)
root_dir = args.data or args.data_dir
dataset = create_dataset( dataset = create_dataset(
root=args.data, root=root_dir,
name=args.dataset, name=args.dataset,
split=args.split, split=args.split,
download=args.dataset_download, download=args.dataset_download,
@ -389,7 +401,7 @@ def main():
if args.model == 'all': if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints # validate all models in a list of names with pretrained checkpoints
args.pretrained = True args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino']) model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae'])
model_cfgs = [(n, '') for n in model_names] model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model): elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter # model name doesn't exist, try as wildcard filter
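Both branches lean on `list_models` wildcard filtering; a quick usage sketch:

from timm.models import list_models

# names matching a wildcard, restricted to pretrained weights,
# minus anything caught by the exclude filters
names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k'])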
