Pass through --model-kwargs (and --opt-kwargs for train) from command line through to model __init__. Update some models to improve arg overlay. Cleanup along the way.

pull/1627/head
Ross Wightman 2 years ago committed by Ross Wightman
parent add3fb864e
commit e861b74cf8

@ -22,7 +22,7 @@ from timm.data import resolve_data_config
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
has_apex = False
try:
@ -108,12 +108,15 @@ parser.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
parser.add_argument('--amp', action='store_true', default=False,
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
parser.add_argument('--precision', default='float32', type=str,
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
# codegen (model compilation) options
scripting_group = parser.add_mutually_exclusive_group()
@ -124,7 +127,6 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd optimization.")
# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"')
@ -168,19 +170,21 @@ def count_params(model: nn.Module):
def resolve_precision(precision: str):
assert precision in ('amp', 'float16', 'bfloat16', 'float32')
use_amp = False
assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
amp_dtype = None # amp disabled
model_dtype = torch.float32
data_dtype = torch.float32
if precision == 'amp':
use_amp = True
amp_dtype = torch.float16
elif precision == 'amp_bfloat16':
amp_dtype = torch.bfloat16
elif precision == 'float16':
model_dtype = torch.float16
data_dtype = torch.float16
elif precision == 'bfloat16':
model_dtype = torch.bfloat16
data_dtype = torch.bfloat16
return use_amp, model_dtype, data_dtype
return amp_dtype, model_dtype, data_dtype
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
@ -228,9 +232,12 @@ class BenchmarkRunner:
self.model_name = model_name
self.detail = detail
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
if self.amp_dtype is not None:
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
else:
self.amp_autocast = suppress
if fuser:
set_jit_fuser(fuser)
@ -243,6 +250,7 @@ class BenchmarkRunner:
drop_rate=kwargs.pop('drop', 0.),
drop_path_rate=kwargs.pop('drop_path', None),
drop_block_rate=kwargs.pop('drop_block', None),
**kwargs.pop('model_kwargs', {}),
)
self.model.to(
device=self.device,
@ -560,7 +568,7 @@ def _try_run(
def benchmark(args):
if args.amp:
_logger.warning("Overriding precision to 'amp' since --amp flag set.")
args.precision = 'amp'
args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
_logger.info(f'Benchmarking in {args.precision} precision. '
f'{"NHWC" if args.channels_last else "NCHW"} layout. '
f'torchscript {"enabled" if args.torchscript else "disabled"}')

@ -20,7 +20,7 @@ import torch
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
try:
from apex import amp
@ -72,6 +72,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
@ -110,6 +112,7 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@ -170,12 +173,19 @@ def main():
set_jit_fuser(args.fuser)
# create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
in_chans=in_chans,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'

@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
def interleave_blocks(
types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
types: Tuple[str, str], d,
every: Union[int, List[int]] = 1,
first: bool = False,
**kwargs,
) -> Tuple[ByoBlockCfg]:
""" interleave 2 block types in stack
"""
@ -1587,15 +1590,32 @@ class ByobNet(nn.Module):
in_chans=3,
global_pool='avg',
output_stride=32,
zero_init_last=True,
img_size=None,
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (ByoModelCfg): Model architecture configuration
num_classes (int): Number of classifier classes (default: 1000)
in_chans (int): Number of input channels (default: 3)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
img_size (Union[int, Tuple[int]): Image size for fixed image size models (i.e. self-attn)
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'

@ -167,7 +167,7 @@ class ConvNeXtStage(nn.Module):
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
))
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
@ -184,16 +184,6 @@ class ConvNeXtStage(nn.Module):
class ConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_rate (float): Head dropout rate
drop_path_rate (float): Stochastic depth rate. Default: 0.
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
@ -218,6 +208,28 @@ class ConvNeXt(nn.Module):
drop_rate=0.,
drop_path_rate=0.,
):
"""
Args:
in_chans (int): Number of input image channels (default: 3)
num_classes (int): Number of classes for classification head (default: 1000)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
ls_init_value (float): Init value for Layer Scale (default: 1e-6)
stem_type (str): Type of stem (default: 'patch')
patch_size (int): Stem patch size for patch stem (default: 4)
head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
head_norm_first (bool): Apply normalization before global pool + head (default: False)
conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
conv_bias (bool): Use bias layers w/ all convolutions (default: True)
use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
act_layer (Union[str, nn.Module]): Activation Layer
norm_layer (Union[str, nn.Module]): Normalization Layer
drop_rate (float): Head dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth rate (default: 0.)
"""
super().__init__()
assert output_stride in (8, 16, 32)
kernel_sizes = to_ntuple(4)(kernel_sizes)
@ -279,7 +291,7 @@ class ConvNeXt(nn.Module):
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
norm_layer_cl=norm_layer_cl,
))
prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2

@ -12,7 +12,7 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman
"""
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, replace
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
@ -518,7 +518,7 @@ class CrossStage(nn.Module):
cross_linear=False,
block_dpr=None,
block_fn=BottleneckBlock,
**block_kwargs
**block_kwargs,
):
super(CrossStage, self).__init__()
first_dilation = first_dilation or dilation
@ -558,7 +558,7 @@ class CrossStage(nn.Module):
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
**block_kwargs,
))
prev_chs = block_out_chs
@ -597,7 +597,7 @@ class CrossStage3(nn.Module):
cross_linear=False,
block_dpr=None,
block_fn=BottleneckBlock,
**block_kwargs
**block_kwargs,
):
super(CrossStage3, self).__init__()
first_dilation = first_dilation or dilation
@ -635,7 +635,7 @@ class CrossStage3(nn.Module):
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
**block_kwargs,
))
prev_chs = block_out_chs
@ -668,7 +668,7 @@ class DarkStage(nn.Module):
avg_down=False,
block_fn=BottleneckBlock,
block_dpr=None,
**block_kwargs
**block_kwargs,
):
super(DarkStage, self).__init__()
first_dilation = first_dilation or dilation
@ -715,7 +715,7 @@ def create_csp_stem(
padding='',
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None
aa_layer=None,
):
stem = nn.Sequential()
feature_info = []
@ -738,7 +738,7 @@ def create_csp_stem(
stride=conv_stride,
padding=padding if i == 0 else '',
act_layer=act_layer,
norm_layer=norm_layer
norm_layer=norm_layer,
))
stem_stride *= conv_stride
prev_chs = chs
@ -800,7 +800,7 @@ def create_csp_stages(
cfg: CspModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any]
stem_feat: Dict[str, Any],
):
cfg_dict = asdict(cfg.stages)
num_stages = len(cfg.stages.depth)
@ -868,12 +868,27 @@ class CspNet(nn.Module):
global_pool='avg',
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (CspModelCfg): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
global_pool (str): Global pooling type (default: 'avg')
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
layer_args = dict(
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,

@ -17,7 +17,7 @@ Status:
Hacked together by / copyright Ross Wightman, 2021.
"""
from collections import OrderedDict
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from typing import Tuple, Optional
@ -159,11 +159,25 @@ class NfCfg:
def _nfres_cfg(
depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
depths,
channels=(256, 512, 1024, 2048),
group_size=None,
act_layer='relu',
attn_layer=None,
attn_kwargs=None,
):
attn_kwargs = attn_kwargs or {}
cfg = NfCfg(
depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='7x7_pool',
stem_chs=64,
bottle_ratio=0.25,
group_size=group_size,
act_layer=act_layer,
attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
)
return cfg
@ -171,28 +185,70 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
num_features = 1280 * channels[-1] // 440
attn_kwargs = dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='3x3',
group_size=8,
width_factor=0.75,
bottle_ratio=2.25,
num_features=num_features,
reg=True,
attn_layer='se',
attn_kwargs=attn_kwargs,
)
return cfg
def _nfnet_cfg(
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
act_layer='gelu', attn_layer='se', attn_kwargs=None):
depths,
channels=(256, 512, 1536, 1536),
group_size=128,
bottle_ratio=0.5,
feat_mult=2.,
act_layer='gelu',
attn_layer='se',
attn_kwargs=None,
):
num_features = int(channels[-1] * feat_mult)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
attn_layer=attn_layer, attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='deep_quad',
stem_chs=128,
group_size=group_size,
bottle_ratio=bottle_ratio,
extra_conv=True,
num_features=num_features,
act_layer=act_layer,
attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
)
return cfg
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
def _dm_nfnet_cfg(
depths,
channels=(256, 512, 1536, 1536),
act_layer='gelu',
skipinit=True,
):
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
depths=depths,
channels=channels,
stem_type='deep_quad',
stem_chs=128,
group_size=128,
bottle_ratio=0.5,
extra_conv=True,
gamma_in_act=True,
same_padding=True,
skipinit=skipinit,
num_features=int(channels[-1] * 2.0),
act_layer=act_layer,
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.5),
)
return cfg
@ -278,7 +334,14 @@ def act_with_gamma(act_type, gamma: float = 1.):
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
conv_layer=ScaledStdConv2d,
):
""" AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
@ -299,9 +362,26 @@ class NormFreeBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
self,
in_chs,
out_chs=None,
stride=1,
dilation=1,
first_dilation=None,
alpha=1.0,
beta=1.0,
bottle_ratio=0.25,
group_size=None,
ch_div=1,
reg=True,
extra_conv=False,
skipinit=False,
attn_layer=None,
attn_gain=2.0,
act_layer=None,
conv_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
out_chs = out_chs or in_chs
@ -316,7 +396,13 @@ class NormFreeBlock(nn.Module):
if in_chs != out_chs or stride != 1 or dilation != first_dilation:
self.downsample = DownsampleAvg(
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
in_chs,
out_chs,
stride=stride,
dilation=dilation,
first_dilation=first_dilation,
conv_layer=conv_layer,
)
else:
self.downsample = None
@ -452,14 +538,33 @@ class NormFreeNet(nn.Module):
for what it is/does. Approx 8-10% throughput loss.
"""
def __init__(
self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
drop_rate=0., drop_path_rate=0.
self,
cfg: NfCfg,
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
drop_rate=0.,
drop_path_rate=0.,
**kwargs,
):
"""
Args:
cfg (NfCfg): Model architecture configuration
num_classes (int): Number of classifier classes (default: 1000)
in_chans (int): Number of input channels (default: 3)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
cfg = replace(cfg, **kwargs)
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
if cfg.gamma_in_act:
@ -472,7 +577,12 @@ class NormFreeNet(nn.Module):
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
self.stem, stem_stride, stem_feat = create_stem(
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
in_chans,
stem_chs,
cfg.stem_type,
conv_layer=conv_layer,
act_layer=act_layer,
)
self.feature_info = [stem_feat]
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]

@ -14,7 +14,7 @@ Weights from original impl have been modified
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from typing import Optional, Union, Callable
@ -237,7 +237,15 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la
def create_shortcut(
downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
downsample_type,
in_chs,
out_chs,
kernel_size,
stride,
dilation=(1, 1),
norm_layer=None,
preact=False,
):
assert downsample_type in ('avg', 'conv1x1', '', None)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
@ -259,9 +267,21 @@ class Bottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
stride=1,
dilation=(1, 1),
bottle_ratio=1,
group_size=1,
se_ratio=0.25,
downsample='conv1x1',
linear_out=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
drop_block=None,
drop_path_rate=0.,
):
super(Bottleneck, self).__init__()
act_layer = get_act_layer(act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -307,9 +327,21 @@ class PreBottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
stride=1,
dilation=(1, 1),
bottle_ratio=1,
group_size=1,
se_ratio=0.25,
downsample='conv1x1',
linear_out=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
drop_block=None,
drop_path_rate=0.,
):
super(PreBottleneck, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -353,8 +385,16 @@ class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape)."""
def __init__(
self, depth, in_chs, out_chs, stride, dilation,
drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
self,
depth,
in_chs,
out_chs,
stride,
dilation,
drop_path_rates=None,
block_fn=Bottleneck,
**block_kwargs,
):
super(RegStage, self).__init__()
self.grad_checkpointing = False
@ -367,8 +407,13 @@ class RegStage(nn.Module):
name = "b{}".format(i + 1)
self.add_module(
name, block_fn(
block_in_chs, out_chs, stride=block_stride, dilation=block_dilation,
drop_path_rate=dpr, **block_kwargs)
block_in_chs,
out_chs,
stride=block_stride,
dilation=block_dilation,
drop_path_rate=dpr,
**block_kwargs,
)
)
first_dilation = dilation
@ -389,12 +434,35 @@ class RegNet(nn.Module):
"""
def __init__(
self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg',
drop_rate=0., drop_path_rate=0., zero_init_last=True):
self,
cfg: RegNetCfg,
in_chans=3,
num_classes=1000,
output_stride=32,
global_pool='avg',
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (RegNetCfg): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
global_pool (str): Global pooling type (default: 'avg')
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
cfg = replace(cfg, **kwargs) # update cfg with extra passed kwargs
# Construct the stem
stem_width = cfg.stem_width
@ -461,8 +529,12 @@ class RegNet(nn.Module):
dict(zip(arg_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)]
common_args = dict(
downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
downsample=cfg.downsample,
se_ratio=cfg.se_ratio,
linear_out=cfg.linear_out,
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
)
return per_stage_args, common_args
@torch.jit.ignore
@ -518,7 +590,6 @@ def _init_weights(module, name='', zero_init_last=False):
def _filter_fn(state_dict):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if 'classy_state_dict' in state_dict:
import re
state_dict = state_dict['classy_state_dict']['base_model']['model']

@ -16,7 +16,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
create_classifier
get_act_layer, get_norm_layer, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, model_entrypoint
@ -500,7 +500,14 @@ class Bottleneck(nn.Module):
def downsample_conv(
in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
first_dilation=None,
norm_layer=None,
):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
@ -514,7 +521,14 @@ def downsample_conv(
def downsample_avg(
in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
first_dilation=None,
norm_layer=None,
):
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
if stride == 1 and dilation == 1:
@ -627,31 +641,6 @@ class ResNet(nn.Module):
SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
Parameters
----------
block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int, number of layers in each block
num_classes : int, default 1000, number of classification classes.
in_chans : int, default 3, number of input (color) channels.
output_stride : int, default 32, output stride of the network, 32, 16, or 8.
global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
stem_width : int, default 64, number of channels in stem convolutions
stem_type : str, default ''
The type of stem:
* '', default - a single 7x7 conv with a width of stem_width
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
block_reduce_first : int, default 1
Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
act_layer : nn.Module, activation layer
norm_layer : nn.Module, normalization layer
aa_layer : nn.Module, anti-aliasing layer
drop_rate : float, default 0. Dropout probability before classifier, for training
"""
def __init__(
@ -679,6 +668,36 @@ class ResNet(nn.Module):
zero_init_last=True,
block_args=None,
):
"""
Args:
block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
layers (List[int]) : number of layers in each block
num_classes (int): number of classification classes (default 1000)
in_chans (int): number of input (color) channels. (default 3)
output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
stem_width (int): number of channels in stem convolutions (default 64)
stem_type (str): The type of stem (default ''):
* '', default - a single 7x7 conv with a width of stem_width
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
1 for all archs except senets, where 2 (default 1)
down_kernel_size (int): kernel size of residual block downsample path,
1x1 for most, 3x3 for senets (default: 1)
avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
act_layer (str, nn.Module): activation layer
norm_layer (str, nn.Module): normalization layer
aa_layer (nn.Module): anti-aliasing layer
drop_rate (float): Dropout probability before classifier, for training (default 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default 0.)
drop_block_rate (float): Drop block rate (default 0.)
zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
block_args (dict): Extra kwargs to pass through to block module
"""
super(ResNet, self).__init__()
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
@ -686,6 +705,9 @@ class ResNet(nn.Module):
self.drop_rate = drop_rate
self.grad_checkpointing = False
act_layer = get_act_layer(act_layer)
norm_layer = get_norm_layer(norm_layer)
# Stem
deep_stem = 'deep' in stem_type
inplanes = stem_width * 2 if deep_stem else 64

@ -37,7 +37,7 @@ import torch.nn as nn
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
from ._registry import register_model
@ -276,8 +276,16 @@ class Bottleneck(nn.Module):
class DownsampleConv(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
conv_layer=None, norm_layer=None):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
preact=True,
conv_layer=None,
norm_layer=None,
):
super(DownsampleConv, self).__init__()
self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
@ -288,8 +296,16 @@ class DownsampleConv(nn.Module):
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
preact=True, conv_layer=None, norm_layer=None):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
preact=True,
conv_layer=None,
norm_layer=None,
):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
@ -334,9 +350,18 @@ class ResNetStage(nn.Module):
drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
stride = stride if block_idx == 0 else 1
self.blocks.add_module(str(block_idx), block_fn(
prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
**layer_kwargs, **block_kwargs))
prev_chs,
out_chs,
stride=stride,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
first_dilation=first_dilation,
proj_layer=proj_layer,
drop_path_rate=drop_path_rate,
**layer_kwargs,
**block_kwargs,
))
prev_chs = out_chs
first_dilation = dilation
proj_layer = None
@ -413,21 +438,49 @@ class ResNetV2(nn.Module):
avg_down=False,
preact=True,
act_layer=nn.ReLU,
conv_layer=StdConv2d,
norm_layer=partial(GroupNormAct, num_groups=32),
conv_layer=StdConv2d,
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=False,
):
"""
Args:
layers (List[int]) : number of layers in each block
channels (List[int]) : number of channels in each block:
num_classes (int): number of classification classes (default 1000)
in_chans (int): number of input (color) channels. (default 3)
global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
width_factor (int): channel (width) multiplication factor
stem_chs (int): stem width (default: 64)
stem_type (str): stem type (default: '' == 7x7)
avg_down (bool): average pooling in residual downsampling (default: False)
preact (bool): pre-activiation (default: True)
act_layer (Union[str, nn.Module]): activation layer
norm_layer (Union[str, nn.Module]): normalization layer
conv_layer (nn.Module): convolution module
drop_rate: classifier dropout rate (default: 0.)
drop_path_rate: stochastic depth rate (default: 0.)
zero_init_last: zero-init last weight in residual path (default: False)
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
wf = width_factor
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
act_layer = get_act_layer(act_layer)
self.feature_info = []
stem_chs = make_div(stem_chs * wf)
self.stem = create_resnetv2_stem(
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
in_chans,
stem_chs,
stem_type,
preact,
conv_layer=conv_layer,
norm_layer=norm_layer,
)
stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

@ -1152,8 +1152,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
def vit_tiny_patch16_224(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16)
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1161,8 +1161,8 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs):
def vit_tiny_patch16_384(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16) @ 384x384.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1170,8 +1170,8 @@ def vit_tiny_patch16_384(pretrained=False, **kwargs):
def vit_small_patch32_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32)
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1179,8 +1179,8 @@ def vit_small_patch32_224(pretrained=False, **kwargs):
def vit_small_patch32_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32) at 384x384.
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1188,8 +1188,8 @@ def vit_small_patch32_384(pretrained=False, **kwargs):
def vit_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1197,8 +1197,8 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
def vit_small_patch16_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1206,8 +1206,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs):
def vit_small_patch8_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/8)
"""
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1216,8 +1216,8 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1226,8 +1226,8 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1236,8 +1236,8 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1246,8 +1246,8 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1256,8 +1256,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1265,8 +1265,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
def vit_large_patch32_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1275,8 +1275,8 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1285,8 +1285,8 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1295,8 +1295,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1304,8 +1304,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
def vit_large_patch14_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
"""
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1313,8 +1313,8 @@ def vit_large_patch14_224(pretrained=False, **kwargs):
def vit_huge_patch14_224(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1322,8 +1322,8 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
def vit_giant_patch14_224(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
"""
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1331,8 +1331,9 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
""" ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
"""
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
model = _create_vision_transformer(
'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1341,8 +1342,9 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
model = _create_vision_transformer(
'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1352,8 +1354,9 @@ def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer(
'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1363,8 +1366,9 @@ def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer(
'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1374,8 +1378,9 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer(
'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1384,9 +1389,9 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
model = _create_vision_transformer(
'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1395,8 +1400,9 @@ def vit_base_patch32_clip_224(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 224x224
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_clip_224', pretrained=pretrained, **model_kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1405,8 +1411,9 @@ def vit_base_patch32_clip_384(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 384x384
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_clip_384', pretrained=pretrained, **model_kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1415,8 +1422,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 448x448
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_clip_448', pretrained=pretrained, **model_kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1424,9 +1432,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
def vit_base_patch16_clip_224(pretrained=False, **kwargs):
""" ViT-B/16 CLIP image tower
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch16_clip_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1434,9 +1442,9 @@ def vit_base_patch16_clip_224(pretrained=False, **kwargs):
def vit_base_patch16_clip_384(pretrained=False, **kwargs):
""" ViT-B/16 CLIP image tower @ 384x384
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch16_clip_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1444,9 +1452,9 @@ def vit_base_patch16_clip_384(pretrained=False, **kwargs):
def vit_large_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14) CLIP image tower
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_clip_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1454,9 +1462,9 @@ def vit_large_patch14_clip_224(pretrained=False, **kwargs):
def vit_large_patch14_clip_336(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_clip_336', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1464,9 +1472,9 @@ def vit_large_patch14_clip_336(pretrained=False, **kwargs):
def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_clip_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1474,9 +1482,9 @@ def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
def vit_huge_patch14_clip_336(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_clip_336', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1486,9 +1494,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
Pretrained weights from CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_clip_224', pretrained=pretrained, **model_kwargs)
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1498,8 +1506,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
def vit_base_patch32_plus_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+)
"""
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
model = _create_vision_transformer(
'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1507,8 +1516,9 @@ def vit_base_patch32_plus_256(pretrained=False, **kwargs):
def vit_base_patch16_plus_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16+)
"""
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
model = _create_vision_transformer(
'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1517,9 +1527,10 @@ def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ residual post-norm
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False,
block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs)
model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
class_token=False, block_fn=ResPostBlock, global_pool='avg')
model = _create_vision_transformer(
'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1529,8 +1540,9 @@ def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5)
model = _create_vision_transformer(
'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1541,8 +1553,9 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
model = _create_vision_transformer(
'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1551,27 +1564,26 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
model = _create_vision_transformer(
'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva_large_patch14_196(pretrained=False, **kwargs):
""" EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
model = _create_vision_transformer(
'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva_large_patch14_336(pretrained=False, **kwargs):
""" EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1579,8 +1591,8 @@ def eva_large_patch14_336(pretrained=False, **kwargs):
def flexivit_small(pretrained=False, **kwargs):
""" FlexiViT-Small
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1588,8 +1600,8 @@ def flexivit_small(pretrained=False, **kwargs):
def flexivit_base(pretrained=False, **kwargs):
""" FlexiViT-Base
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1597,6 +1609,6 @@ def flexivit_base(pretrained=False, **kwargs):
def flexivit_large(pretrained=False, **kwargs):
""" FlexiViT-Large
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model

@ -181,8 +181,18 @@ class SequentialAppendList(nn.Sequential):
class OsaBlock(nn.Module):
def __init__(
self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
self,
in_chs,
mid_chs,
out_chs,
layer_per_block,
residual=False,
depthwise=False,
attn='',
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_path=None,
):
super(OsaBlock, self).__init__()
self.residual = residual
@ -232,9 +242,20 @@ class OsaBlock(nn.Module):
class OsaStage(nn.Module):
def __init__(
self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
drop_path_rates=None):
self,
in_chs,
mid_chs,
out_chs,
block_per_stage,
layer_per_block,
downsample=True,
residual=True,
depthwise=False,
attn='ese',
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_path_rates=None,
):
super(OsaStage, self).__init__()
self.grad_checkpointing = False
@ -270,16 +291,38 @@ class OsaStage(nn.Module):
class VovNet(nn.Module):
def __init__(
self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
""" VovNet (v2)
self,
cfg,
in_chans=3,
num_classes=1000,
global_pool='avg',
output_stride=32,
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_rate=0.,
drop_path_rate=0.,
**kwargs,
):
"""
Args:
cfg (dict): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
norm_layer (Union[str, nn.Module]): normalization layer
act_layer (Union[str, nn.Module]): activation layer
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super(VovNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert stem_stride in (4, 2)
assert output_stride == 32 # FIXME support dilation
cfg = dict(cfg, **kwargs)
stem_stride = cfg.get("stem_stride", 4)
stem_chs = cfg["stem_chs"]
stage_conv_chs = cfg["stage_conv_chs"]
stage_out_chs = cfg["stage_out_chs"]
@ -307,9 +350,15 @@ class VovNet(nn.Module):
for i in range(4): # num_stages
downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4
stages += [OsaStage(
in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
]
in_ch_list[i],
stage_conv_chs[i],
stage_out_chs[i],
block_per_stage[i],
layer_per_block,
downsample=downsample,
drop_path_rates=stage_dpr[i],
**stage_args,
)]
self.num_features = stage_out_chs[i]
current_stride *= 2 if downsample else 1
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
@ -324,7 +373,6 @@ class VovNet(nn.Module):
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(

@ -8,7 +8,7 @@ from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed

@ -2,6 +2,8 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import argparse
import ast
import re
@ -16,3 +18,15 @@ def add_bool_arg(parser, name, default=False, help=''):
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
parser.set_defaults(**{dest_name: default})
class ParseKwargs(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
kw = {}
for value in values:
key, value = value.split('=')
try:
kw[key] = ast.literal_eval(value)
except ValueError:
kw[key] = str(value) # fallback to string (avoid need to escape on command line)
setattr(namespace, self.dest, kw)

@ -89,56 +89,58 @@ parser.add_argument('--data-dir', metavar='DIR',
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
help='dataset validation split (default: validation)')
group.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
help='path to class to idx mapping file (default: "")')
# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50")')
help='Name of model to train (default: "resnet50")')
group.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
help='Start with pretrained version of specified network (if avail)')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
help='Initialize model from this checkpoint (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (Model default if None)')
help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image size (default: None => model default)')
help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
metavar='N N N',
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='Input batch size for training (default: 128)')
help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='Validation batch size override (default: None)')
help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
@ -151,199 +153,200 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
help='weight decay (default: 2e-5)')
help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
help='LR scheduler (default: "step"')
help='LR scheduler (default: "step"')
group.add_argument('--sched-on-updates', action='store_true', default=False,
help='Apply LR scheduler step on update instead of epoch end.')
help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate, overrides lr-base if set (default: None)')
help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / base_size')
help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
help='base learning rate batch size (divisor, default: 256).')
help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
help='warmup learning rate (default: 1e-5)')
help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
help='epoch interval to decay LR')
help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
help='Exclude warmup period from decay schedule.'),
help='Exclude warmup period from decay schedule.'),
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
help='Color jitter factor (default: 0.4)')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
help='Enable tracking moving average of model weights')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
help='number of checkpoints to keep (default: 10)')
help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
help='save images of input bathes every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
help='Best metric (default: "top1"')
group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
help='log training and validation metrics to wandb')
def _parse_args():
@ -371,8 +374,6 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
if args.data and not args.data_dir:
args.data_dir = args.data
args.prefetcher = not args.no_prefetcher
device = utils.init_distributed_device(args)
if args.distributed:
@ -383,14 +384,6 @@ def main():
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
@ -432,6 +425,7 @@ def main():
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -504,7 +498,11 @@ def main():
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
optimizer = create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
**args.opt_kwargs,
)
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
@ -559,6 +557,8 @@ def main():
# NOTE: EMA model does not need to be wrapped by DDP
# create the train and eval datasets
if args.data and not args.data_dir:
args.data_dir = args.data
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
@ -712,6 +712,14 @@ def main():
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# setup learning rate schedule and starting epoch
updates_per_epoch = len(loader_train)
lr_scheduler, num_epochs = create_scheduler_v2(

@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
decay_batch_step, check_batch_size_retry
decay_batch_step, check_batch_size_retry, ParseKwargs
try:
from apex import amp
@ -71,6 +71,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
@ -123,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@ -181,13 +185,20 @@ def validate(args):
set_fast_norm()
# create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
in_chans=in_chans,
global_pool=args.gp,
scriptable=args.torchscript,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -232,8 +243,9 @@ def validate(args):
criterion = nn.CrossEntropyLoss().to(device)
root_dir = args.data or args.data_dir
dataset = create_dataset(
root=args.data,
root=root_dir,
name=args.dataset,
split=args.split,
download=args.dataset_download,
@ -389,7 +401,7 @@ def main():
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino'])
model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae'])
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter

Loading…
Cancel
Save