From e861b74cf805232204c4645f68bfb43333e4d0a5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 6 Jan 2023 12:01:43 -0800 Subject: [PATCH] Pass through --model-kwargs (and --opt-kwargs for train) from command line through to model __init__. Update some models to improve arg overlay. Cleanup along the way. --- benchmark.py | 26 +-- inference.py | 14 +- timm/models/byobnet.py | 24 ++- timm/models/convnext.py | 36 +++-- timm/models/cspnet.py | 35 ++-- timm/models/nfnet.py | 156 +++++++++++++++--- timm/models/regnet.py | 105 ++++++++++-- timm/models/resnet.py | 78 +++++---- timm/models/resnetv2.py | 73 +++++++-- timm/models/vision_transformer.py | 220 +++++++++++++------------ timm/models/vovnet.py | 74 +++++++-- timm/utils/__init__.py | 2 +- timm/utils/misc.py | 14 ++ train.py | 258 +++++++++++++++--------------- validate.py | 20 ++- 15 files changed, 775 insertions(+), 360 deletions(-) diff --git a/benchmark.py b/benchmark.py index 58435ff8..2cce3e2c 100755 --- a/benchmark.py +++ b/benchmark.py @@ -22,7 +22,7 @@ from timm.data import resolve_data_config from timm.layers import set_fast_norm from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 -from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry +from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs has_apex = False try: @@ -108,12 +108,15 @@ parser.add_argument('--grad-checkpointing', action='store_true', default=False, help='Enable gradient checkpointing through model blocks/stages') parser.add_argument('--amp', action='store_true', default=False, help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.') +parser.add_argument('--amp-dtype', default='float16', type=str, + help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.') parser.add_argument('--precision', default='float32', type=str, help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--fast-norm', default=False, action='store_true', help='enable experimental fast-norm') +parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) # codegen (model compilation) options scripting_group = parser.add_mutually_exclusive_group() @@ -124,7 +127,6 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None scripting_group.add_argument('--aot-autograd', default=False, action='store_true', help="Enable AOT Autograd optimization.") - # train optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -168,19 +170,21 @@ def count_params(model: nn.Module): def resolve_precision(precision: str): - assert precision in ('amp', 'float16', 'bfloat16', 'float32') - use_amp = False + assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32') + amp_dtype = None # amp disabled model_dtype = torch.float32 data_dtype = torch.float32 if precision == 'amp': - use_amp = True + amp_dtype = torch.float16 + elif precision == 'amp_bfloat16': + amp_dtype = torch.bfloat16 elif precision == 'float16': model_dtype = torch.float16 data_dtype = torch.float16 elif precision == 'bfloat16': model_dtype = torch.bfloat16 data_dtype = torch.bfloat16 - return use_amp, model_dtype, data_dtype + return amp_dtype, model_dtype, data_dtype def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): @@ -228,9 +232,12 @@ class BenchmarkRunner: self.model_name = model_name self.detail = detail self.device = device - self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) + self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision) self.channels_last = kwargs.pop('channels_last', False) - self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress + if self.amp_dtype is not None: + self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype) + else: + self.amp_autocast = suppress if fuser: set_jit_fuser(fuser) @@ -243,6 +250,7 @@ class BenchmarkRunner: drop_rate=kwargs.pop('drop', 0.), drop_path_rate=kwargs.pop('drop_path', None), drop_block_rate=kwargs.pop('drop_block', None), + **kwargs.pop('model_kwargs', {}), ) self.model.to( device=self.device, @@ -560,7 +568,7 @@ def _try_run( def benchmark(args): if args.amp: _logger.warning("Overriding precision to 'amp' since --amp flag set.") - args.precision = 'amp' + args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype]) _logger.info(f'Benchmarking in {args.precision} precision. ' f'{"NHWC" if args.channels_last else "NCHW"} layout. ' f'torchscript {"enabled" if args.torchscript else "disabled"}') diff --git a/inference.py b/inference.py index 1509b323..cfbe62d1 100755 --- a/inference.py +++ b/inference.py @@ -20,7 +20,7 @@ import torch from timm.data import create_dataset, create_loader, resolve_data_config from timm.layers import apply_test_time_pool from timm.models import create_model -from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser +from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs try: from apex import amp @@ -72,6 +72,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--in-chans', type=int, default=None, metavar='N', + help='Image input channels (default: None => 3)') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') parser.add_argument('--use-train-size', action='store_true', default=False, @@ -110,6 +112,7 @@ parser.add_argument('--amp-dtype', default='float16', type=str, help='lower precision AMP dtype (default: float16)') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) scripting_group = parser.add_mutually_exclusive_group() scripting_group.add_argument('--torchscript', default=False, action='store_true', @@ -170,12 +173,19 @@ def main(): set_jit_fuser(args.fuser) # create model + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + elif args.input_size is not None: + in_chans = args.input_size[0] + model = create_model( args.model, num_classes=args.num_classes, - in_chans=3, + in_chans=in_chans, pretrained=args.pretrained, checkpoint_path=args.checkpoint, + **args.model_kwargs, ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 15f78044..1c7f1137 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): def interleave_blocks( - types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs + types: Tuple[str, str], d, + every: Union[int, List[int]] = 1, + first: bool = False, + **kwargs, ) -> Tuple[ByoBlockCfg]: """ interleave 2 block types in stack """ @@ -1587,15 +1590,32 @@ class ByobNet(nn.Module): in_chans=3, global_pool='avg', output_stride=32, - zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0., + zero_init_last=True, + **kwargs, ): + """ + + Args: + cfg (ByoModelCfg): Model architecture configuration + num_classes (int): Number of classifier classes (default: 1000) + in_chans (int): Number of input channels (default: 3) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + img_size (Union[int, Tuple[int]): Image size for fixed image size models (i.e. self-attn) + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + zero_init_last (bool): Zero-init last weight of residual path + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False + + cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg layers = get_layer_fns(cfg) if cfg.fixed_input_size: assert img_size is not None, 'img_size argument is required for fixed input size model' diff --git a/timm/models/convnext.py b/timm/models/convnext.py index e9214429..e799a7de 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -167,7 +167,7 @@ class ConvNeXtStage(nn.Module): conv_bias=conv_bias, use_grn=use_grn, act_layer=act_layer, - norm_layer=norm_layer if conv_mlp else norm_layer_cl + norm_layer=norm_layer if conv_mlp else norm_layer_cl, )) in_chs = out_chs self.blocks = nn.Sequential(*stage_blocks) @@ -184,16 +184,6 @@ class ConvNeXtStage(nn.Module): class ConvNeXt(nn.Module): r""" ConvNeXt A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf - - Args: - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] - dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768] - drop_rate (float): Head dropout rate - drop_path_rate (float): Stochastic depth rate. Default: 0. - ls_init_value (float): Init value for Layer Scale. Default: 1e-6. - head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. """ def __init__( @@ -218,6 +208,28 @@ class ConvNeXt(nn.Module): drop_rate=0., drop_path_rate=0., ): + """ + Args: + in_chans (int): Number of input image channels (default: 3) + num_classes (int): Number of classes for classification head (default: 1000) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3]) + dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768]) + kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7) + ls_init_value (float): Init value for Layer Scale (default: 1e-6) + stem_type (str): Type of stem (default: 'patch') + patch_size (int): Stem patch size for patch stem (default: 4) + head_init_scale (float): Init scaling value for classifier weights and biases (default: 1) + head_norm_first (bool): Apply normalization before global pool + head (default: False) + conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False) + conv_bias (bool): Use bias layers w/ all convolutions (default: True) + use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False) + act_layer (Union[str, nn.Module]): Activation Layer + norm_layer (Union[str, nn.Module]): Normalization Layer + drop_rate (float): Head dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth rate (default: 0.) + """ super().__init__() assert output_stride in (8, 16, 32) kernel_sizes = to_ntuple(4)(kernel_sizes) @@ -279,7 +291,7 @@ class ConvNeXt(nn.Module): use_grn=use_grn, act_layer=act_layer, norm_layer=norm_layer, - norm_layer_cl=norm_layer_cl + norm_layer_cl=norm_layer_cl, )) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 280f929e..26ec54d9 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,7 +12,7 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ -from dataclasses import dataclass, asdict +from dataclasses import dataclass, asdict, replace from functools import partial from typing import Any, Dict, Optional, Tuple, Union @@ -518,7 +518,7 @@ class CrossStage(nn.Module): cross_linear=False, block_dpr=None, block_fn=BottleneckBlock, - **block_kwargs + **block_kwargs, ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation @@ -558,7 +558,7 @@ class CrossStage(nn.Module): bottle_ratio=bottle_ratio, groups=groups, drop_path=block_dpr[i] if block_dpr is not None else 0., - **block_kwargs + **block_kwargs, )) prev_chs = block_out_chs @@ -597,7 +597,7 @@ class CrossStage3(nn.Module): cross_linear=False, block_dpr=None, block_fn=BottleneckBlock, - **block_kwargs + **block_kwargs, ): super(CrossStage3, self).__init__() first_dilation = first_dilation or dilation @@ -635,7 +635,7 @@ class CrossStage3(nn.Module): bottle_ratio=bottle_ratio, groups=groups, drop_path=block_dpr[i] if block_dpr is not None else 0., - **block_kwargs + **block_kwargs, )) prev_chs = block_out_chs @@ -668,7 +668,7 @@ class DarkStage(nn.Module): avg_down=False, block_fn=BottleneckBlock, block_dpr=None, - **block_kwargs + **block_kwargs, ): super(DarkStage, self).__init__() first_dilation = first_dilation or dilation @@ -715,7 +715,7 @@ def create_csp_stem( padding='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - aa_layer=None + aa_layer=None, ): stem = nn.Sequential() feature_info = [] @@ -738,7 +738,7 @@ def create_csp_stem( stride=conv_stride, padding=padding if i == 0 else '', act_layer=act_layer, - norm_layer=norm_layer + norm_layer=norm_layer, )) stem_stride *= conv_stride prev_chs = chs @@ -800,7 +800,7 @@ def create_csp_stages( cfg: CspModelCfg, drop_path_rate: float, output_stride: int, - stem_feat: Dict[str, Any] + stem_feat: Dict[str, Any], ): cfg_dict = asdict(cfg.stages) num_stages = len(cfg.stages.depth) @@ -868,12 +868,27 @@ class CspNet(nn.Module): global_pool='avg', drop_rate=0., drop_path_rate=0., - zero_init_last=True + zero_init_last=True, + **kwargs, ): + """ + Args: + cfg (CspModelCfg): Model architecture configuration + in_chans (int): Number of input channels (default: 3) + num_classes (int): Number of classifier classes (default: 1000) + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + global_pool (str): Global pooling type (default: 'avg') + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + zero_init_last (bool): Zero-init last weight of residual path + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) + + cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg layer_args = dict( act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 48f91b35..f9a90ab3 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -17,7 +17,7 @@ Status: Hacked together by / copyright Ross Wightman, 2021. """ from collections import OrderedDict -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial from typing import Tuple, Optional @@ -159,11 +159,25 @@ class NfCfg: def _nfres_cfg( - depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None): + depths, + channels=(256, 512, 1024, 2048), + group_size=None, + act_layer='relu', + attn_layer=None, + attn_kwargs=None, +): attn_kwargs = attn_kwargs or {} cfg = NfCfg( - depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25, - group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + depths=depths, + channels=channels, + stem_type='7x7_pool', + stem_chs=64, + bottle_ratio=0.25, + group_size=group_size, + act_layer=act_layer, + attn_layer=attn_layer, + attn_kwargs=attn_kwargs, + ) return cfg @@ -171,28 +185,70 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): num_features = 1280 * channels[-1] // 440 attn_kwargs = dict(rd_ratio=0.5) cfg = NfCfg( - depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, - num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) + depths=depths, + channels=channels, + stem_type='3x3', + group_size=8, + width_factor=0.75, + bottle_ratio=2.25, + num_features=num_features, + reg=True, + attn_layer='se', + attn_kwargs=attn_kwargs, + ) return cfg def _nfnet_cfg( - depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2., - act_layer='gelu', attn_layer='se', attn_kwargs=None): + depths, + channels=(256, 512, 1536, 1536), + group_size=128, + bottle_ratio=0.5, + feat_mult=2., + act_layer='gelu', + attn_layer='se', + attn_kwargs=None, +): num_features = int(channels[-1] * feat_mult) attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5) cfg = NfCfg( - depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size, - bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer, - attn_layer=attn_layer, attn_kwargs=attn_kwargs) + depths=depths, + channels=channels, + stem_type='deep_quad', + stem_chs=128, + group_size=group_size, + bottle_ratio=bottle_ratio, + extra_conv=True, + num_features=num_features, + act_layer=act_layer, + attn_layer=attn_layer, + attn_kwargs=attn_kwargs, + ) return cfg -def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True): +def _dm_nfnet_cfg( + depths, + channels=(256, 512, 1536, 1536), + act_layer='gelu', + skipinit=True, +): cfg = NfCfg( - depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128, - bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit, - num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5)) + depths=depths, + channels=channels, + stem_type='deep_quad', + stem_chs=128, + group_size=128, + bottle_ratio=0.5, + extra_conv=True, + gamma_in_act=True, + same_padding=True, + skipinit=skipinit, + num_features=int(channels[-1] * 2.0), + act_layer=act_layer, + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.5), + ) return cfg @@ -278,7 +334,14 @@ def act_with_gamma(act_type, gamma: float = 1.): class DownsampleAvg(nn.Module): def __init__( - self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d): + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + conv_layer=ScaledStdConv2d, + ): """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation.""" super(DownsampleAvg, self).__init__() avg_stride = stride if dilation == 1 else 1 @@ -299,9 +362,26 @@ class NormFreeBlock(nn.Module): """ def __init__( - self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None, - alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False, - skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.): + self, + in_chs, + out_chs=None, + stride=1, + dilation=1, + first_dilation=None, + alpha=1.0, + beta=1.0, + bottle_ratio=0.25, + group_size=None, + ch_div=1, + reg=True, + extra_conv=False, + skipinit=False, + attn_layer=None, + attn_gain=2.0, + act_layer=None, + conv_layer=None, + drop_path_rate=0., + ): super().__init__() first_dilation = first_dilation or dilation out_chs = out_chs or in_chs @@ -316,7 +396,13 @@ class NormFreeBlock(nn.Module): if in_chs != out_chs or stride != 1 or dilation != first_dilation: self.downsample = DownsampleAvg( - in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer) + in_chs, + out_chs, + stride=stride, + dilation=dilation, + first_dilation=first_dilation, + conv_layer=conv_layer, + ) else: self.downsample = None @@ -452,14 +538,33 @@ class NormFreeNet(nn.Module): for what it is/does. Approx 8-10% throughput loss. """ def __init__( - self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - drop_rate=0., drop_path_rate=0. + self, + cfg: NfCfg, + num_classes=1000, + in_chans=3, + global_pool='avg', + output_stride=32, + drop_rate=0., + drop_path_rate=0., + **kwargs, ): + """ + Args: + cfg (NfCfg): Model architecture configuration + num_classes (int): Number of classifier classes (default: 1000) + in_chans (int): Number of input channels (default: 3) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False + cfg = replace(cfg, **kwargs) assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})." conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d if cfg.gamma_in_act: @@ -472,7 +577,12 @@ class NormFreeNet(nn.Module): stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) self.stem, stem_stride, stem_feat = create_stem( - in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) + in_chans, + stem_chs, + cfg.stem_type, + conv_layer=conv_layer, + act_layer=act_layer, + ) self.feature_info = [stem_feat] drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] diff --git a/timm/models/regnet.py b/timm/models/regnet.py index e1cc821b..9d2528f6 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -14,7 +14,7 @@ Weights from original impl have been modified Hacked together by / Copyright 2020 Ross Wightman """ import math -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial from typing import Optional, Union, Callable @@ -237,7 +237,15 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la def create_shortcut( - downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False): + downsample_type, + in_chs, + out_chs, + kernel_size, + stride, + dilation=(1, 1), + norm_layer=None, + preact=False, +): assert downsample_type in ('avg', 'conv1x1', '', None) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact) @@ -259,9 +267,21 @@ class Bottleneck(nn.Module): """ def __init__( - self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, - downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_block=None, drop_path_rate=0.): + self, + in_chs, + out_chs, + stride=1, + dilation=(1, 1), + bottle_ratio=1, + group_size=1, + se_ratio=0.25, + downsample='conv1x1', + linear_out=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_block=None, + drop_path_rate=0., + ): super(Bottleneck, self).__init__() act_layer = get_act_layer(act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) @@ -307,9 +327,21 @@ class PreBottleneck(nn.Module): """ def __init__( - self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, - downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_block=None, drop_path_rate=0.): + self, + in_chs, + out_chs, + stride=1, + dilation=(1, 1), + bottle_ratio=1, + group_size=1, + se_ratio=0.25, + downsample='conv1x1', + linear_out=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_block=None, + drop_path_rate=0., + ): super(PreBottleneck, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) @@ -353,8 +385,16 @@ class RegStage(nn.Module): """Stage (sequence of blocks w/ the same output shape).""" def __init__( - self, depth, in_chs, out_chs, stride, dilation, - drop_path_rates=None, block_fn=Bottleneck, **block_kwargs): + self, + depth, + in_chs, + out_chs, + stride, + dilation, + drop_path_rates=None, + block_fn=Bottleneck, + **block_kwargs, + ): super(RegStage, self).__init__() self.grad_checkpointing = False @@ -367,8 +407,13 @@ class RegStage(nn.Module): name = "b{}".format(i + 1) self.add_module( name, block_fn( - block_in_chs, out_chs, stride=block_stride, dilation=block_dilation, - drop_path_rate=dpr, **block_kwargs) + block_in_chs, + out_chs, + stride=block_stride, + dilation=block_dilation, + drop_path_rate=dpr, + **block_kwargs, + ) ) first_dilation = dilation @@ -389,12 +434,35 @@ class RegNet(nn.Module): """ def __init__( - self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', - drop_rate=0., drop_path_rate=0., zero_init_last=True): + self, + cfg: RegNetCfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0., + drop_path_rate=0., + zero_init_last=True, + **kwargs, + ): + """ + + Args: + cfg (RegNetCfg): Model architecture configuration + in_chans (int): Number of input channels (default: 3) + num_classes (int): Number of classifier classes (default: 1000) + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + global_pool (str): Global pooling type (default: 'avg') + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + zero_init_last (bool): Zero-init last weight of residual path + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) + cfg = replace(cfg, **kwargs) # update cfg with extra passed kwargs # Construct the stem stem_width = cfg.stem_width @@ -461,8 +529,12 @@ class RegNet(nn.Module): dict(zip(arg_names, params)) for params in zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)] common_args = dict( - downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out, - act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) + downsample=cfg.downsample, + se_ratio=cfg.se_ratio, + linear_out=cfg.linear_out, + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + ) return per_stage_args, common_args @torch.jit.ignore @@ -518,7 +590,6 @@ def _init_weights(module, name='', zero_init_last=False): def _filter_fn(state_dict): - """ convert patch embedding weight from manual patchify + linear proj to conv""" if 'classy_state_dict' in state_dict: import re state_dict = state_dict['classy_state_dict']['base_model']['model'] diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 2976c1f9..a783e3e1 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ - create_classifier + get_act_layer, get_norm_layer, create_classifier from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, model_entrypoint @@ -500,7 +500,14 @@ class Bottleneck(nn.Module): def downsample_conv( - in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + first_dilation=None, + norm_layer=None, +): norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 @@ -514,7 +521,14 @@ def downsample_conv( def downsample_avg( - in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + first_dilation=None, + norm_layer=None, +): norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 if stride == 1 and dilation == 1: @@ -627,31 +641,6 @@ class ResNet(nn.Module): SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block - - Parameters - ---------- - block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl. - layers : list of int, number of layers in each block - num_classes : int, default 1000, number of classification classes. - in_chans : int, default 3, number of input (color) channels. - output_stride : int, default 32, output stride of the network, 32, 16, or 8. - global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' - cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck. - base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality` - stem_width : int, default 64, number of channels in stem convolutions - stem_type : str, default '' - The type of stem: - * '', default - a single 7x7 conv with a width of stem_width - * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 - * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 - block_reduce_first : int, default 1 - Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2 - down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets - avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample. - act_layer : nn.Module, activation layer - norm_layer : nn.Module, normalization layer - aa_layer : nn.Module, anti-aliasing layer - drop_rate : float, default 0. Dropout probability before classifier, for training """ def __init__( @@ -679,6 +668,36 @@ class ResNet(nn.Module): zero_init_last=True, block_args=None, ): + """ + Args: + block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck. + layers (List[int]) : number of layers in each block + num_classes (int): number of classification classes (default 1000) + in_chans (int): number of input (color) channels. (default 3) + output_stride (int): output stride of the network, 32, 16, or 8. (default 32) + global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg') + cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1) + base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64) + stem_width (int): number of channels in stem convolutions (default 64) + stem_type (str): The type of stem (default ''): + * '', default - a single 7x7 conv with a width of stem_width + * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 + * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 + replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution + block_reduce_first (int): Reduction factor for first convolution output width of residual blocks, + 1 for all archs except senets, where 2 (default 1) + down_kernel_size (int): kernel size of residual block downsample path, + 1x1 for most, 3x3 for senets (default: 1) + avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False) + act_layer (str, nn.Module): activation layer + norm_layer (str, nn.Module): normalization layer + aa_layer (nn.Module): anti-aliasing layer + drop_rate (float): Dropout probability before classifier, for training (default 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default 0.) + drop_block_rate (float): Drop block rate (default 0.) + zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) + block_args (dict): Extra kwargs to pass through to block module + """ super(ResNet, self).__init__() block_args = block_args or dict() assert output_stride in (8, 16, 32) @@ -686,6 +705,9 @@ class ResNet(nn.Module): self.drop_rate = drop_rate self.grad_checkpointing = False + act_layer = get_act_layer(act_layer) + norm_layer = get_norm_layer(norm_layer) + # Stem deep_stem = 'deep' in stem_type inplanes = stem_width * 2 if deep_stem else 64 diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index a55f48ac..d696b291 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -37,7 +37,7 @@ import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \ - ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d + ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv from ._registry import register_model @@ -276,8 +276,16 @@ class Bottleneck(nn.Module): class DownsampleConv(nn.Module): def __init__( - self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True, - conv_layer=None, norm_layer=None): + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + ): super(DownsampleConv, self).__init__() self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) @@ -288,8 +296,16 @@ class DownsampleConv(nn.Module): class DownsampleAvg(nn.Module): def __init__( - self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, - preact=True, conv_layer=None, norm_layer=None): + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + ): """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" super(DownsampleAvg, self).__init__() avg_stride = stride if dilation == 1 else 1 @@ -334,9 +350,18 @@ class ResNetStage(nn.Module): drop_path_rate = block_dpr[block_idx] if block_dpr else 0. stride = stride if block_idx == 0 else 1 self.blocks.add_module(str(block_idx), block_fn( - prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups, - first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate, - **layer_kwargs, **block_kwargs)) + prev_chs, + out_chs, + stride=stride, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + first_dilation=first_dilation, + proj_layer=proj_layer, + drop_path_rate=drop_path_rate, + **layer_kwargs, + **block_kwargs, + )) prev_chs = out_chs first_dilation = dilation proj_layer = None @@ -413,21 +438,49 @@ class ResNetV2(nn.Module): avg_down=False, preact=True, act_layer=nn.ReLU, - conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), + conv_layer=StdConv2d, drop_rate=0., drop_path_rate=0., zero_init_last=False, ): + """ + Args: + layers (List[int]) : number of layers in each block + channels (List[int]) : number of channels in each block: + num_classes (int): number of classification classes (default 1000) + in_chans (int): number of input (color) channels. (default 3) + global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg') + output_stride (int): output stride of the network, 32, 16, or 8. (default 32) + width_factor (int): channel (width) multiplication factor + stem_chs (int): stem width (default: 64) + stem_type (str): stem type (default: '' == 7x7) + avg_down (bool): average pooling in residual downsampling (default: False) + preact (bool): pre-activiation (default: True) + act_layer (Union[str, nn.Module]): activation layer + norm_layer (Union[str, nn.Module]): normalization layer + conv_layer (nn.Module): convolution module + drop_rate: classifier dropout rate (default: 0.) + drop_path_rate: stochastic depth rate (default: 0.) + zero_init_last: zero-init last weight in residual path (default: False) + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate wf = width_factor + norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) + act_layer = get_act_layer(act_layer) self.feature_info = [] stem_chs = make_div(stem_chs * wf) self.stem = create_resnetv2_stem( - in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) + in_chans, + stem_chs, + stem_type, + preact, + conv_layer=conv_layer, + norm_layer=norm_layer, + ) stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm' self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index d6865549..9441a3b2 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1152,8 +1152,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): def vit_tiny_patch16_224(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) """ - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1161,8 +1161,8 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs): def vit_tiny_patch16_384(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1170,8 +1170,8 @@ def vit_tiny_patch16_384(pretrained=False, **kwargs): def vit_small_patch32_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1179,8 +1179,8 @@ def vit_small_patch32_224(pretrained=False, **kwargs): def vit_small_patch32_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) at 384x384. """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1188,8 +1188,8 @@ def vit_small_patch32_384(pretrained=False, **kwargs): def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1197,8 +1197,8 @@ def vit_small_patch16_224(pretrained=False, **kwargs): def vit_small_patch16_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1206,8 +1206,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs): def vit_small_patch8_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/8) """ - model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1216,8 +1216,8 @@ def vit_base_patch32_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1226,8 +1226,8 @@ def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1236,8 +1236,8 @@ def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1246,8 +1246,8 @@ def vit_base_patch16_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1256,8 +1256,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1265,8 +1265,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs): def vit_large_patch32_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1275,8 +1275,8 @@ def vit_large_patch32_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1285,8 +1285,8 @@ def vit_large_patch16_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1295,8 +1295,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1304,8 +1304,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs): def vit_large_patch14_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) """ - model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1313,8 +1313,8 @@ def vit_large_patch14_224(pretrained=False, **kwargs): def vit_huge_patch14_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). """ - model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) + model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1322,8 +1322,8 @@ def vit_huge_patch14_224(pretrained=False, **kwargs): def vit_giant_patch14_224(pretrained=False, **kwargs): """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ - model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16) + model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1331,8 +1331,9 @@ def vit_giant_patch14_224(pretrained=False, **kwargs): def vit_gigantic_patch14_224(pretrained=False, **kwargs): """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ - model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) + model = _create_vision_transformer( + 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1341,8 +1342,9 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False) + model = _create_vision_transformer( + 'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1352,8 +1354,9 @@ def vit_medium_patch16_gap_240(pretrained=False, **kwargs): """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs) + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1363,8 +1366,9 @@ def vit_medium_patch16_gap_256(pretrained=False, **kwargs): """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs) + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1374,8 +1378,9 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs): """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs) + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1384,9 +1389,9 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256 """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1395,8 +1400,9 @@ def vit_base_patch32_clip_224(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 224x224 """ model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch32_clip_224', pretrained=pretrained, **model_kwargs) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1405,8 +1411,9 @@ def vit_base_patch32_clip_384(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 384x384 """ model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch32_clip_384', pretrained=pretrained, **model_kwargs) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1415,8 +1422,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 448x448 """ model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch32_clip_448', pretrained=pretrained, **model_kwargs) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1424,9 +1432,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs): def vit_base_patch16_clip_224(pretrained=False, **kwargs): """ ViT-B/16 CLIP image tower """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch16_clip_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1434,9 +1442,9 @@ def vit_base_patch16_clip_224(pretrained=False, **kwargs): def vit_base_patch16_clip_384(pretrained=False, **kwargs): """ ViT-B/16 CLIP image tower @ 384x384 """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch16_clip_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1444,9 +1452,9 @@ def vit_base_patch16_clip_384(pretrained=False, **kwargs): def vit_large_patch14_clip_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) CLIP image tower """ - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_large_patch14_clip_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1454,9 +1462,9 @@ def vit_large_patch14_clip_224(pretrained=False, **kwargs): def vit_large_patch14_clip_336(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 """ - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_large_patch14_clip_336', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1464,9 +1472,9 @@ def vit_large_patch14_clip_336(pretrained=False, **kwargs): def vit_huge_patch14_clip_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) CLIP image tower. """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_clip_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1474,9 +1482,9 @@ def vit_huge_patch14_clip_224(pretrained=False, **kwargs): def vit_huge_patch14_clip_336(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_clip_336', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1486,9 +1494,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs): Pretrained weights from CLIP image tower. """ model_kwargs = dict( - patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, - pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_giant_patch14_clip_224', pretrained=pretrained, **model_kwargs) + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1498,8 +1506,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs): def vit_base_patch32_plus_256(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32+) """ - model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) - model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1507,8 +1516,9 @@ def vit_base_patch32_plus_256(pretrained=False, **kwargs): def vit_base_patch16_plus_240(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16+) """ - model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) - model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1517,9 +1527,10 @@ def vit_base_patch16_rpn_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ residual post-norm """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, - block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) - model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + class_token=False, block_fn=ResPostBlock, global_pool='avg') + model = _create_vision_transformer( + 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1529,8 +1540,9 @@ def vit_small_patch16_36x1_224(pretrained=False, **kwargs): Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) - model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5) + model = _create_vision_transformer( + 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1541,8 +1553,9 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs): Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_kwargs = dict( - patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) - model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock) + model = _create_vision_transformer( + 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1551,27 +1564,26 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) - model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock) + model = _create_vision_transformer( + 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def eva_large_patch14_196(pretrained=False, **kwargs): """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) - model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer( + 'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def eva_large_patch14_336(pretrained=False, **kwargs): """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) - model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1579,8 +1591,8 @@ def eva_large_patch14_336(pretrained=False, **kwargs): def flexivit_small(pretrained=False, **kwargs): """ FlexiViT-Small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs) - model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True) + model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1588,8 +1600,8 @@ def flexivit_small(pretrained=False, **kwargs): def flexivit_base(pretrained=False, **kwargs): """ FlexiViT-Base """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs) - model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True) + model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1597,6 +1609,6 @@ def flexivit_base(pretrained=False, **kwargs): def flexivit_large(pretrained=False, **kwargs): """ FlexiViT-Large """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs) - model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True) + model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index bf0e4f89..8aea5802 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -181,8 +181,18 @@ class SequentialAppendList(nn.Sequential): class OsaBlock(nn.Module): def __init__( - self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, - depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None): + self, + in_chs, + mid_chs, + out_chs, + layer_per_block, + residual=False, + depthwise=False, + attn='', + norm_layer=BatchNormAct2d, + act_layer=nn.ReLU, + drop_path=None, + ): super(OsaBlock, self).__init__() self.residual = residual @@ -232,9 +242,20 @@ class OsaBlock(nn.Module): class OsaStage(nn.Module): def __init__( - self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True, - residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, - drop_path_rates=None): + self, + in_chs, + mid_chs, + out_chs, + block_per_stage, + layer_per_block, + downsample=True, + residual=True, + depthwise=False, + attn='ese', + norm_layer=BatchNormAct2d, + act_layer=nn.ReLU, + drop_path_rates=None, + ): super(OsaStage, self).__init__() self.grad_checkpointing = False @@ -270,16 +291,38 @@ class OsaStage(nn.Module): class VovNet(nn.Module): def __init__( - self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, - output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.): - """ VovNet (v2) + self, + cfg, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + norm_layer=BatchNormAct2d, + act_layer=nn.ReLU, + drop_rate=0., + drop_path_rate=0., + **kwargs, + ): + """ + Args: + cfg (dict): Model architecture configuration + in_chans (int): Number of input channels (default: 3) + num_classes (int): Number of classifier classes (default: 1000) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + norm_layer (Union[str, nn.Module]): normalization layer + act_layer (Union[str, nn.Module]): activation layer + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + kwargs (dict): Extra kwargs overlayed onto cfg """ super(VovNet, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate - assert stem_stride in (4, 2) assert output_stride == 32 # FIXME support dilation + cfg = dict(cfg, **kwargs) + stem_stride = cfg.get("stem_stride", 4) stem_chs = cfg["stem_chs"] stage_conv_chs = cfg["stage_conv_chs"] stage_out_chs = cfg["stage_out_chs"] @@ -307,9 +350,15 @@ class VovNet(nn.Module): for i in range(4): # num_stages downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4 stages += [OsaStage( - in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block, - downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args) - ] + in_ch_list[i], + stage_conv_chs[i], + stage_out_chs[i], + block_per_stage[i], + layer_per_block, + downsample=downsample, + drop_path_rates=stage_dpr[i], + **stage_args, + )] self.num_features = stage_out_chs[i] current_stride *= 2 if downsample else 1 self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] @@ -324,7 +373,6 @@ class VovNet(nn.Module): elif isinstance(m, nn.Linear): nn.init.zeros_(m.bias) - @torch.jit.ignore def group_matcher(self, coarse=False): return dict( diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index a9ff0c78..7727adff 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -8,7 +8,7 @@ from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\ from .jit import set_jit_legacy, set_jit_fuser from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy -from .misc import natural_key, add_bool_arg +from .misc import natural_key, add_bool_arg, ParseKwargs from .model import unwrap_model, get_state_dict, freeze, unfreeze from .model_ema import ModelEma, ModelEmaV2 from .random import random_seed diff --git a/timm/utils/misc.py b/timm/utils/misc.py index 39c0097c..326a50f7 100644 --- a/timm/utils/misc.py +++ b/timm/utils/misc.py @@ -2,6 +2,8 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import argparse +import ast import re @@ -16,3 +18,15 @@ def add_bool_arg(parser, name, default=False, help=''): group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) parser.set_defaults(**{dest_name: default}) + + +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + kw = {} + for value in values: + key, value = value.split('=') + try: + kw[key] = ast.literal_eval(value) + except ValueError: + kw[key] = str(value) # fallback to string (avoid need to escape on command line) + setattr(namespace, self.dest, kw) diff --git a/train.py b/train.py index e51d7c90..9f450ab8 100755 --- a/train.py +++ b/train.py @@ -89,56 +89,58 @@ parser.add_argument('--data-dir', metavar='DIR', parser.add_argument('--dataset', metavar='NAME', default='', help='dataset type + name ("/") (default: ImageFolder or ImageTar if empty)') group.add_argument('--train-split', metavar='NAME', default='train', - help='dataset train split (default: train)') + help='dataset train split (default: train)') group.add_argument('--val-split', metavar='NAME', default='validation', - help='dataset validation split (default: validation)') + help='dataset validation split (default: validation)') group.add_argument('--dataset-download', action='store_true', default=False, - help='Allow download of dataset for torch/ and tfds/ datasets that support it.') + help='Allow download of dataset for torch/ and tfds/ datasets that support it.') group.add_argument('--class-map', default='', type=str, metavar='FILENAME', - help='path to class to idx mapping file (default: "")') + help='path to class to idx mapping file (default: "")') # Model parameters group = parser.add_argument_group('Model parameters') group.add_argument('--model', default='resnet50', type=str, metavar='MODEL', - help='Name of model to train (default: "resnet50")') + help='Name of model to train (default: "resnet50")') group.add_argument('--pretrained', action='store_true', default=False, - help='Start with pretrained version of specified network (if avail)') + help='Start with pretrained version of specified network (if avail)') group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', - help='Initialize model from this checkpoint (default: none)') + help='Initialize model from this checkpoint (default: none)') group.add_argument('--resume', default='', type=str, metavar='PATH', - help='Resume full model and optimizer state from checkpoint (default: none)') + help='Resume full model and optimizer state from checkpoint (default: none)') group.add_argument('--no-resume-opt', action='store_true', default=False, - help='prevent resume of optimizer state when resuming model') + help='prevent resume of optimizer state when resuming model') group.add_argument('--num-classes', type=int, default=None, metavar='N', - help='number of label classes (Model default if None)') + help='number of label classes (Model default if None)') group.add_argument('--gp', default=None, type=str, metavar='POOL', - help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') group.add_argument('--img-size', type=int, default=None, metavar='N', - help='Image size (default: None => model default)') + help='Image size (default: None => model default)') group.add_argument('--in-chans', type=int, default=None, metavar='N', - help='Image input channels (default: None => 3)') + help='Image input channels (default: None => 3)') group.add_argument('--input-size', default=None, nargs=3, type=int, - metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') + metavar='N N N', + help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') group.add_argument('--crop-pct', default=None, type=float, - metavar='N', help='Input image center crop percent (for validation only)') + metavar='N', help='Input image center crop percent (for validation only)') group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', - help='Override mean pixel value of dataset') + help='Override mean pixel value of dataset') group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', - help='Override std deviation of dataset') + help='Override std deviation of dataset') group.add_argument('--interpolation', default='', type=str, metavar='NAME', - help='Image resize interpolation type (overrides model)') + help='Image resize interpolation type (overrides model)') group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', - help='Input batch size for training (default: 128)') + help='Input batch size for training (default: 128)') group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', - help='Validation batch size override (default: None)') + help='Validation batch size override (default: None)') group.add_argument('--channels-last', action='store_true', default=False, - help='Use channels_last memory layout') + help='Use channels_last memory layout') group.add_argument('--fuser', default='', type=str, - help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") + help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") group.add_argument('--grad-checkpointing', action='store_true', default=False, - help='Enable gradient checkpointing through model blocks/stages') + help='Enable gradient checkpointing through model blocks/stages') group.add_argument('--fast-norm', default=False, action='store_true', - help='enable experimental fast-norm') + help='enable experimental fast-norm') +group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs) scripting_group = group.add_mutually_exclusive_group() scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', @@ -151,199 +153,200 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true # Optimizer parameters group = parser.add_argument_group('Optimizer parameters') group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', - help='Optimizer (default: "sgd")') + help='Optimizer (default: "sgd")') group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', - help='Optimizer Epsilon (default: None, use opt default)') + help='Optimizer Epsilon (default: None, use opt default)') group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', - help='Optimizer Betas (default: None, use opt default)') + help='Optimizer Betas (default: None, use opt default)') group.add_argument('--momentum', type=float, default=0.9, metavar='M', - help='Optimizer momentum (default: 0.9)') + help='Optimizer momentum (default: 0.9)') group.add_argument('--weight-decay', type=float, default=2e-5, - help='weight decay (default: 2e-5)') + help='weight decay (default: 2e-5)') group.add_argument('--clip-grad', type=float, default=None, metavar='NORM', - help='Clip gradient norm (default: None, no clipping)') + help='Clip gradient norm (default: None, no clipping)') group.add_argument('--clip-mode', type=str, default='norm', - help='Gradient clipping mode. One of ("norm", "value", "agc")') + help='Gradient clipping mode. One of ("norm", "value", "agc")') group.add_argument('--layer-decay', type=float, default=None, - help='layer-wise learning rate decay (default: None)') + help='layer-wise learning rate decay (default: None)') +group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs) # Learning rate schedule parameters group = parser.add_argument_group('Learning rate schedule parameters') group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER', - help='LR scheduler (default: "step"') + help='LR scheduler (default: "step"') group.add_argument('--sched-on-updates', action='store_true', default=False, - help='Apply LR scheduler step on update instead of epoch end.') + help='Apply LR scheduler step on update instead of epoch end.') group.add_argument('--lr', type=float, default=None, metavar='LR', - help='learning rate, overrides lr-base if set (default: None)') + help='learning rate, overrides lr-base if set (default: None)') group.add_argument('--lr-base', type=float, default=0.1, metavar='LR', - help='base learning rate: lr = lr_base * global_batch_size / base_size') + help='base learning rate: lr = lr_base * global_batch_size / base_size') group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV', - help='base learning rate batch size (divisor, default: 256).') + help='base learning rate batch size (divisor, default: 256).') group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE', - help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') + help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', - help='learning rate noise on/off epoch percentages') + help='learning rate noise on/off epoch percentages') group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', - help='learning rate noise limit percent (default: 0.67)') + help='learning rate noise limit percent (default: 0.67)') group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', - help='learning rate noise std-dev (default: 1.0)') + help='learning rate noise std-dev (default: 1.0)') group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', - help='learning rate cycle len multiplier (default: 1.0)') + help='learning rate cycle len multiplier (default: 1.0)') group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', - help='amount to decay each learning rate cycle (default: 0.5)') + help='amount to decay each learning rate cycle (default: 0.5)') group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', - help='learning rate cycle limit, cycles enabled if > 1') + help='learning rate cycle limit, cycles enabled if > 1') group.add_argument('--lr-k-decay', type=float, default=1.0, - help='learning rate k-decay for cosine/poly (default: 1.0)') + help='learning rate k-decay for cosine/poly (default: 1.0)') group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR', - help='warmup learning rate (default: 1e-5)') + help='warmup learning rate (default: 1e-5)') group.add_argument('--min-lr', type=float, default=0, metavar='LR', - help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') + help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') group.add_argument('--epochs', type=int, default=300, metavar='N', - help='number of epochs to train (default: 300)') + help='number of epochs to train (default: 300)') group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', - help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') + help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') group.add_argument('--start-epoch', default=None, type=int, metavar='N', - help='manual epoch number (useful on restarts)') + help='manual epoch number (useful on restarts)') group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES", - help='list of decay epoch indices for multistep lr. must be increasing') + help='list of decay epoch indices for multistep lr. must be increasing') group.add_argument('--decay-epochs', type=float, default=90, metavar='N', - help='epoch interval to decay LR') + help='epoch interval to decay LR') group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', - help='epochs to warmup LR, if scheduler supports') + help='epochs to warmup LR, if scheduler supports') group.add_argument('--warmup-prefix', action='store_true', default=False, - help='Exclude warmup period from decay schedule.'), + help='Exclude warmup period from decay schedule.'), group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', - help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') group.add_argument('--patience-epochs', type=int, default=10, metavar='N', - help='patience epochs for Plateau LR scheduler (default: 10)') + help='patience epochs for Plateau LR scheduler (default: 10)') group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', - help='LR decay rate (default: 0.1)') + help='LR decay rate (default: 0.1)') # Augmentation & regularization parameters group = parser.add_argument_group('Augmentation and regularization parameters') group.add_argument('--no-aug', action='store_true', default=False, - help='Disable all training augmentation, override other train aug args') + help='Disable all training augmentation, override other train aug args') group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', - help='Random resize scale (default: 0.08 1.0)') -group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', - help='Random resize aspect ratio (default: 0.75 1.33)') + help='Random resize scale (default: 0.08 1.0)') +group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', + help='Random resize aspect ratio (default: 0.75 1.33)') group.add_argument('--hflip', type=float, default=0.5, - help='Horizontal flip training aug probability') + help='Horizontal flip training aug probability') group.add_argument('--vflip', type=float, default=0., - help='Vertical flip training aug probability') + help='Vertical flip training aug probability') group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', - help='Color jitter factor (default: 0.4)') + help='Color jitter factor (default: 0.4)') group.add_argument('--aa', type=str, default=None, metavar='NAME', - help='Use AutoAugment policy. "v0" or "original". (default: None)'), + help='Use AutoAugment policy. "v0" or "original". (default: None)'), group.add_argument('--aug-repeats', type=float, default=0, - help='Number of augmentation repetitions (distributed training only) (default: 0)') + help='Number of augmentation repetitions (distributed training only) (default: 0)') group.add_argument('--aug-splits', type=int, default=0, - help='Number of augmentation splits (default: 0, valid: 0 or >=2)') + help='Number of augmentation splits (default: 0, valid: 0 or >=2)') group.add_argument('--jsd-loss', action='store_true', default=False, - help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') group.add_argument('--bce-loss', action='store_true', default=False, - help='Enable BCE loss w/ Mixup/CutMix use.') + help='Enable BCE loss w/ Mixup/CutMix use.') group.add_argument('--bce-target-thresh', type=float, default=None, - help='Threshold for binarizing softened BCE targets (default: None, disabled)') + help='Threshold for binarizing softened BCE targets (default: None, disabled)') group.add_argument('--reprob', type=float, default=0., metavar='PCT', - help='Random erase prob (default: 0.)') + help='Random erase prob (default: 0.)') group.add_argument('--remode', type=str, default='pixel', - help='Random erase mode (default: "pixel")') + help='Random erase mode (default: "pixel")') group.add_argument('--recount', type=int, default=1, - help='Random erase count (default: 1)') + help='Random erase count (default: 1)') group.add_argument('--resplit', action='store_true', default=False, - help='Do not random erase first (clean) augmentation split') + help='Do not random erase first (clean) augmentation split') group.add_argument('--mixup', type=float, default=0.0, - help='mixup alpha, mixup enabled if > 0. (default: 0.)') + help='mixup alpha, mixup enabled if > 0. (default: 0.)') group.add_argument('--cutmix', type=float, default=0.0, - help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') + help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, - help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') group.add_argument('--mixup-prob', type=float, default=1.0, - help='Probability of performing mixup or cutmix when either/both is enabled') + help='Probability of performing mixup or cutmix when either/both is enabled') group.add_argument('--mixup-switch-prob', type=float, default=0.5, - help='Probability of switching to cutmix when both mixup and cutmix enabled') + help='Probability of switching to cutmix when both mixup and cutmix enabled') group.add_argument('--mixup-mode', type=str, default='batch', - help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') + help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', - help='Turn off mixup after this epoch, disabled if 0 (default: 0)') + help='Turn off mixup after this epoch, disabled if 0 (default: 0)') group.add_argument('--smoothing', type=float, default=0.1, - help='Label smoothing (default: 0.1)') + help='Label smoothing (default: 0.1)') group.add_argument('--train-interpolation', type=str, default='random', - help='Training interpolation (random, bilinear, bicubic default: "random")') + help='Training interpolation (random, bilinear, bicubic default: "random")') group.add_argument('--drop', type=float, default=0.0, metavar='PCT', - help='Dropout rate (default: 0.)') + help='Dropout rate (default: 0.)') group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', - help='Drop connect rate, DEPRECATED, use drop-path (default: None)') + help='Drop connect rate, DEPRECATED, use drop-path (default: None)') group.add_argument('--drop-path', type=float, default=None, metavar='PCT', - help='Drop path rate (default: None)') + help='Drop path rate (default: None)') group.add_argument('--drop-block', type=float, default=None, metavar='PCT', - help='Drop block rate (default: None)') + help='Drop block rate (default: None)') # Batch norm parameters (only works with gen_efficientnet based models currently) group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') group.add_argument('--bn-momentum', type=float, default=None, - help='BatchNorm momentum override (if not None)') + help='BatchNorm momentum override (if not None)') group.add_argument('--bn-eps', type=float, default=None, - help='BatchNorm epsilon override (if not None)') + help='BatchNorm epsilon override (if not None)') group.add_argument('--sync-bn', action='store_true', - help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') + help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') group.add_argument('--dist-bn', type=str, default='reduce', - help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') + help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') group.add_argument('--split-bn', action='store_true', - help='Enable separate BN layers per augmentation split.') + help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average group = parser.add_argument_group('Model exponential moving average parameters') group.add_argument('--model-ema', action='store_true', default=False, - help='Enable tracking moving average of model weights') + help='Enable tracking moving average of model weights') group.add_argument('--model-ema-force-cpu', action='store_true', default=False, - help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') + help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') group.add_argument('--model-ema-decay', type=float, default=0.9998, - help='decay factor for model weights moving average (default: 0.9998)') + help='decay factor for model weights moving average (default: 0.9998)') # Misc group = parser.add_argument_group('Miscellaneous parameters') group.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') + help='random seed (default: 42)') group.add_argument('--worker-seeding', type=str, default='all', - help='worker seed mode (default: all)') + help='worker seed mode (default: all)') group.add_argument('--log-interval', type=int, default=50, metavar='N', - help='how many batches to wait before logging training status') + help='how many batches to wait before logging training status') group.add_argument('--recovery-interval', type=int, default=0, metavar='N', - help='how many batches to wait before writing recovery checkpoint') + help='how many batches to wait before writing recovery checkpoint') group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', - help='number of checkpoints to keep (default: 10)') + help='number of checkpoints to keep (default: 10)') group.add_argument('-j', '--workers', type=int, default=4, metavar='N', - help='how many training processes to use (default: 4)') + help='how many training processes to use (default: 4)') group.add_argument('--save-images', action='store_true', default=False, - help='save images of input bathes every log interval for debugging') + help='save images of input bathes every log interval for debugging') group.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') group.add_argument('--amp-dtype', default='float16', type=str, - help='lower precision AMP dtype (default: float16)') + help='lower precision AMP dtype (default: float16)') group.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') + help='AMP impl to use, "native" or "apex" (default: native)') group.add_argument('--no-ddp-bb', action='store_true', default=False, - help='Force broadcast buffers for native DDP to off.') + help='Force broadcast buffers for native DDP to off.') group.add_argument('--pin-mem', action='store_true', default=False, - help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') group.add_argument('--no-prefetcher', action='store_true', default=False, - help='disable fast prefetcher') + help='disable fast prefetcher') group.add_argument('--output', default='', type=str, metavar='PATH', - help='path to output folder (default: none, current dir)') + help='path to output folder (default: none, current dir)') group.add_argument('--experiment', default='', type=str, metavar='NAME', - help='name of train experiment, name of sub-folder for output') + help='name of train experiment, name of sub-folder for output') group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', - help='Best metric (default: "top1"') + help='Best metric (default: "top1"') group.add_argument('--tta', type=int, default=0, metavar='N', - help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') + help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') group.add_argument("--local_rank", default=0, type=int) group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, - help='use the multi-epochs-loader to save time at the beginning of every epoch') + help='use the multi-epochs-loader to save time at the beginning of every epoch') group.add_argument('--log-wandb', action='store_true', default=False, - help='log training and validation metrics to wandb') + help='log training and validation metrics to wandb') def _parse_args(): @@ -371,8 +374,6 @@ def main(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True - if args.data and not args.data_dir: - args.data_dir = args.data args.prefetcher = not args.no_prefetcher device = utils.init_distributed_device(args) if args.distributed: @@ -383,14 +384,6 @@ def main(): _logger.info(f'Training with a single process on 1 device ({args.device}).') assert args.rank >= 0 - if utils.is_primary(args) and args.log_wandb: - if has_wandb: - wandb.init(project=args.experiment, config=args) - else: - _logger.warning( - "You've requested to log metrics to wandb but package not found. " - "Metrics not being logged to wandb, try `pip install wandb`") - # resolve AMP arguments based on PyTorch / Apex availability use_amp = None amp_dtype = torch.float16 @@ -432,6 +425,7 @@ def main(): bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint, + **args.model_kwargs, ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' @@ -504,7 +498,11 @@ def main(): f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') - optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) + optimizer = create_optimizer_v2( + model, + **optimizer_kwargs(cfg=args), + **args.opt_kwargs, + ) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing @@ -559,6 +557,8 @@ def main(): # NOTE: EMA model does not need to be wrapped by DDP # create the train and eval datasets + if args.data and not args.data_dir: + args.data_dir = args.data dataset_train = create_dataset( args.dataset, root=args.data_dir, @@ -712,6 +712,14 @@ def main(): with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) + if utils.is_primary(args) and args.log_wandb: + if has_wandb: + wandb.init(project=args.experiment, config=args) + else: + _logger.warning( + "You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") + # setup learning rate schedule and starting epoch updates_per_epoch = len(loader_train) lr_scheduler, num_epochs = create_scheduler_v2( diff --git a/validate.py b/validate.py index 4669fbac..b606103d 100755 --- a/validate.py +++ b/validate.py @@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa from timm.layers import apply_test_time_pool, set_fast_norm from timm.models import create_model, load_checkpoint, is_model, list_models from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ - decay_batch_step, check_batch_size_retry + decay_batch_step, check_batch_size_retry, ParseKwargs try: from apex import amp @@ -71,6 +71,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--in-chans', type=int, default=None, metavar='N', + help='Image input channels (default: None => 3)') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') parser.add_argument('--use-train-size', action='store_true', default=False, @@ -123,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--fast-norm', default=False, action='store_true', help='enable experimental fast-norm') +parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) + scripting_group = parser.add_mutually_exclusive_group() scripting_group.add_argument('--torchscript', default=False, action='store_true', @@ -181,13 +185,20 @@ def validate(args): set_fast_norm() # create model + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + elif args.input_size is not None: + in_chans = args.input_size[0] + model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, - in_chans=3, + in_chans=in_chans, global_pool=args.gp, scriptable=args.torchscript, + **args.model_kwargs, ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' @@ -232,8 +243,9 @@ def validate(args): criterion = nn.CrossEntropyLoss().to(device) + root_dir = args.data or args.data_dir dataset = create_dataset( - root=args.data, + root=root_dir, name=args.dataset, split=args.split, download=args.dataset_download, @@ -389,7 +401,7 @@ def main(): if args.model == 'all': # validate all models in a list of names with pretrained checkpoints args.pretrained = True - model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino']) + model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae']) model_cfgs = [(n, '') for n in model_names] elif not is_model(args.model): # model name doesn't exist, try as wildcard filter