diff --git a/timm/data/dataset.py b/timm/data/dataset.py index e7f67925..17c08e4d 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset): root, reader=None, split='train', + class_map=None, is_training=False, batch_size=None, seed=42, @@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset): reader, root=root, split=split, + class_map=class_map, is_training=is_training, batch_size=batch_size, seed=seed, diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index a4c18e39..6f0dcfcd 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -157,6 +157,7 @@ def create_dataset( root, reader=name, split=split, + class_map=class_map, is_training=is_training, download=download, batch_size=batch_size, @@ -169,6 +170,7 @@ def create_dataset( root, reader=name, split=split, + class_map=class_map, is_training=is_training, batch_size=batch_size, repeats=repeats, diff --git a/timm/data/readers/reader_tfds.py b/timm/data/readers/reader_tfds.py index 25aab471..012a27a9 100644 --- a/timm/data/readers/reader_tfds.py +++ b/timm/data/readers/reader_tfds.py @@ -34,6 +34,7 @@ except ImportError as e: print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") exit(1) +from .class_map import load_class_map from .reader import Reader from .shared_count import SharedCount @@ -94,6 +95,7 @@ class ReaderTfds(Reader): root, name, split='train', + class_map=None, is_training=False, batch_size=None, download=False, @@ -151,7 +153,12 @@ class ReaderTfds(Reader): # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag if download: self.builder.download_and_prepare() - self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {} + self.remap_class = False + if class_map: + self.class_to_idx = load_class_map(class_map) + self.remap_class = True + else: + self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {} self.split_info = self.builder.info.splits[split] self.num_samples = self.split_info.num_examples @@ -299,6 +306,8 @@ class ReaderTfds(Reader): target_data = sample[self.target_name] if self.target_img_mode: target_data = Image.fromarray(target_data, mode=self.target_img_mode) + elif self.remap_class: + target_data = self.class_to_idx[target_data] yield input_data, target_data sample_count += 1 if self.is_training and sample_count >= target_sample_count: diff --git a/timm/data/readers/reader_wds.py b/timm/data/readers/reader_wds.py index 36890eed..3bf99d26 100644 --- a/timm/data/readers/reader_wds.py +++ b/timm/data/readers/reader_wds.py @@ -29,6 +29,7 @@ except ImportError: wds = None expand_urls = None +from .class_map import load_class_map from .reader import Reader from .shared_count import SharedCount @@ -42,13 +43,13 @@ def _load_info(root, basename='info'): info_yaml = os.path.join(root, basename + '.yaml') err_str = '' try: - with wds.gopen.gopen(info_json) as f: + with wds.gopen(info_json) as f: info_dict = json.load(f) return info_dict except Exception as e: err_str = str(e) try: - with wds.gopen.gopen(info_yaml) as f: + with wds.gopen(info_yaml) as f: info_dict = yaml.safe_load(f) return info_dict except Exception: @@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict): filenames=split_filenames, ) else: - if split not in info['splits']: - raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})") + 
if 'splits' not in info or split not in info['splits']: + raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})") split = split split_info = info['splits'][split] split_info = _info_convert(split_info) @@ -290,6 +291,7 @@ class ReaderWds(Reader): batch_size=None, repeats=0, seed=42, + class_map=None, input_name='jpg', input_image='RGB', target_name='cls', @@ -320,6 +322,12 @@ class ReaderWds(Reader): self.num_samples = self.split_info.num_samples if not self.num_samples: raise RuntimeError(f'Invalid split definition, no samples found.') + self.remap_class = False + if class_map: + self.class_to_idx = load_class_map(class_map) + self.remap_class = True + else: + self.class_to_idx = {} # Distributed world state self.dist_rank = 0 @@ -431,7 +439,10 @@ class ReaderWds(Reader): i = 0 # _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug for sample in ds: - yield sample[self.image_key], sample[self.target_key] + target = sample[self.target_key] + if self.remap_class: + target = self.class_to_idx[target] + yield sample[self.image_key], target i += 1 # _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug diff --git a/timm/models/__init__.py b/timm/models/__init__.py index a9fbbc26..05cbbc81 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -17,6 +17,7 @@ from .edgenext import * from .efficientformer import * from .efficientformer_v2 import * from .efficientnet import * +from .focalnet import * from .gcvit import * from .ghostnet import * from .gluon_resnet import * diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py new file mode 100644 index 00000000..8178cfc3 --- /dev/null +++ b/timm/models/focalnet.py @@ -0,0 +1,637 @@ +""" FocalNet + +As described in `Focal Modulation Networks` - https://arxiv.org/abs/2203.11926 + +Significant modifications and refactoring from the original impl at https://github.com/microsoft/FocalNet + +This impl is/has: +* fully convolutional, NCHW tensor layout throughout, seemed to have minimal performance impact but more flexible +* re-ordered downsample / layer so that striding always at beginning of layer (stage) +* no input size constraints or input resolution/H/W tracking through the model +* torchscript fixed and a number of quirks cleaned up +* feature extraction support via `features_only=True` +""" +# -------------------------------------------------------- +# FocalNets -- Focal Modulation Networks +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang (jianwyan@microsoft.com) +# -------------------------------------------------------- +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead +from ._builder import build_model_with_cfg +from ._manipulate import named_apply +from ._registry import register_model + +__all__ = ['FocalNet'] + + +class FocalModulation(nn.Module): + def __init__( + self, + dim, + focal_window, + focal_level, + focal_factor=2, + bias=True, + use_post_norm=False, + normalize_modulator=False, + proj_drop=0., + norm_layer=LayerNorm2d, + ): + super().__init__() + + self.dim = dim + self.focal_window = focal_window + self.focal_level = focal_level + self.focal_factor = focal_factor + self.use_post_norm = use_post_norm + self.normalize_modulator = 
normalize_modulator + self.input_split = [dim, dim, self.focal_level + 1] + + self.f = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias) + self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + self.act = nn.GELU() + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + self.proj_drop = nn.Dropout(proj_drop) + self.focal_layers = nn.ModuleList() + + self.kernel_sizes = [] + for k in range(self.focal_level): + kernel_size = self.focal_factor * k + self.focal_window + self.focal_layers.append(nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=kernel_size // 2, bias=False), + nn.GELU(), + )) + self.kernel_sizes.append(kernel_size) + self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity() + + def forward(self, x): + """ + Args: + x: input features with shape of (B, H, W, C) + """ + C = x.shape[1] + + # pre linear projection + x = self.f(x) + q, ctx, gates = torch.split(x, self.input_split, 1) + + # context aggreation + ctx_all = 0 + for l, focal_layer in enumerate(self.focal_layers): + ctx = focal_layer(ctx) + ctx_all = ctx_all + ctx * gates[:, l:l + 1] + ctx_global = self.act(ctx.mean((2, 3), keepdim=True)) + ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:] + + # normalize context + if self.normalize_modulator: + ctx_all = ctx_all / (self.focal_level + 1) + + # focal modulation + x_out = q * self.h(ctx_all) + x_out = self.norm(x_out) + + # post linear projection + x_out = self.proj(x_out) + x_out = self.proj_drop(x_out) + return x_out + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class FocalNetBlock(nn.Module): + r""" Focal Modulation Network Block. + + Args: + dim (int): Number of input channels. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + proj_drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + focal_level (int): Number of focal levels. + focal_window (int): Focal window size at first focal level + layerscale_value (float): Initial layerscale value + use_post_norm (bool): Whether to use layernorm after modulation + """ + + def __init__( + self, + dim, + mlp_ratio=4., + focal_level=1, + focal_window=3, + use_post_norm=False, + use_post_norm_in_modulation=False, + normalize_modulator=False, + layerscale_value=1e-4, + proj_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=LayerNorm2d, + ): + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + + self.focal_window = focal_window + self.focal_level = focal_level + self.use_post_norm = use_post_norm + + self.norm1 = norm_layer(dim) if not use_post_norm else nn.Identity() + self.modulation = FocalModulation( + dim, + focal_window=focal_window, + focal_level=self.focal_level, + use_post_norm=use_post_norm_in_modulation, + normalize_modulator=normalize_modulator, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.norm1_post = norm_layer(dim) if use_post_norm else nn.Identity() + self.ls1 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) if not use_post_norm else nn.Identity() + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + use_conv=True, + ) + self.norm2_post = norm_layer(dim) if use_post_norm else nn.Identity() + self.ls2 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + # Focal Modulation + x = self.norm1(x) + x = self.modulation(x) + x = self.norm1_post(x) + x = shortcut + self.drop_path1(self.ls1(x)) + + # FFN + x = x + self.drop_path2(self.ls2(self.norm2_post(self.mlp(self.norm2(x))))) + + return x + + +class BasicLayer(nn.Module): + """ A basic Focal Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (bool): Downsample layer at start of the layer. Default: True + focal_level (int): Number of focal levels + focal_window (int): Focal window size at first focal level + layerscale_value (float): Initial layerscale value + use_post_norm (bool): Whether to use layer norm after modulation + """ + + def __init__( + self, + dim, + out_dim, + depth, + mlp_ratio=4., + downsample=True, + focal_level=1, + focal_window=1, + use_overlap_down=False, + use_post_norm=False, + use_post_norm_in_modulation=False, + normalize_modulator=False, + layerscale_value=1e-4, + proj_drop=0., + drop_path=0., + norm_layer=LayerNorm2d, + ): + + super().__init__() + self.dim = dim + self.depth = depth + self.grad_checkpointing = False + + if downsample: + self.downsample = Downsample( + in_chs=dim, + out_chs=out_dim, + stride=2, + overlap=use_overlap_down, + norm_layer=norm_layer, + ) + else: + self.downsample = nn.Identity() + + # build blocks + self.blocks = nn.ModuleList([ + FocalNetBlock( + dim=out_dim, + mlp_ratio=mlp_ratio, + focal_level=focal_level, + focal_window=focal_window, + use_post_norm=use_post_norm, + use_post_norm_in_modulation=use_post_norm_in_modulation, + normalize_modulator=normalize_modulator, + layerscale_value=layerscale_value, + proj_drop=proj_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + def forward(self, x): + x = self.downsample(x) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + return x + + +class Downsample(nn.Module): + r""" + Args: + in_chs (int): Number of input image channels + out_chs (int): Number of linear projection output channels + stride (int): Downsample stride. Default: 4. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__( + self, + in_chs, + out_chs, + stride=4, + overlap=False, + norm_layer=None, + ): + super().__init__() + self.stride = stride + padding = 0 + kernel_size = stride + if overlap: + assert stride in (2, 4) + if stride == 4: + kernel_size, padding = 7, 2 + elif stride == 2: + kernel_size, padding = 3, 1 + self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding) + self.norm = norm_layer(out_chs) if norm_layer is not None else nn.Identity() + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + + +class FocalNet(nn.Module): + r""" Focal Modulation Networks (FocalNets) + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Focal Transformer layer. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. + Default: [1, 1, 1, 1] + focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1] + use_overlap_down (bool): Whether to use convolutional embedding. + use_post_norm (bool): Whether to use layernorm after modulation (it helps stablize training of large models) + layerscale_value (float): Value for layer scale. Default: 1e-4 + drop_rate (float): Dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + + """ + + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + embed_dim=96, + depths=(2, 2, 6, 2), + mlp_ratio=4., + focal_levels=(2, 2, 2, 2), + focal_windows=(3, 3, 3, 3), + use_overlap_down=False, + use_post_norm=False, + use_post_norm_in_modulation=False, + normalize_modulator=False, + head_hidden_size=None, + head_init_scale=1.0, + layerscale_value=None, + drop_rate=0., + proj_drop_rate=0., + drop_path_rate=0.1, + norm_layer=partial(LayerNorm2d, eps=1e-5), + **kwargs, + ): + super().__init__() + + self.num_layers = len(depths) + embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)] + + self.num_classes = num_classes + self.embed_dim = embed_dim + self.num_features = embed_dim[-1] + self.feature_info = [] + + self.stem = Downsample( + in_chs=in_chans, + out_chs=embed_dim[0], + overlap=use_overlap_down, + norm_layer=norm_layer, + ) + in_dim = embed_dim[0] + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + layers = [] + for i_layer in range(self.num_layers): + out_dim = embed_dim[i_layer] + layer = BasicLayer( + dim=in_dim, + out_dim=out_dim, + depth=depths[i_layer], + mlp_ratio=mlp_ratio, + downsample=i_layer > 0, + focal_level=focal_levels[i_layer], + focal_window=focal_windows[i_layer], + use_overlap_down=use_overlap_down, + use_post_norm=use_post_norm, + use_post_norm_in_modulation=use_post_norm_in_modulation, + normalize_modulator=normalize_modulator, + layerscale_value=layerscale_value, + proj_drop=proj_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + ) + in_dim = out_dim + layers += [layer] + self.feature_info += [dict(num_chs=out_dim, reduction=4 * 2 ** i_layer, module=f'layers.{i_layer}')] + + self.layers = nn.Sequential(*layers) + + if head_hidden_size: + self.norm = nn.Identity() + self.head = 
NormMlpClassifierHead( + self.num_features, + num_classes, + hidden_size=head_hidden_size, + pool_type=global_pool, + drop_rate=drop_rate, + norm_layer=norm_layer, + ) + else: + self.norm = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate + ) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + @torch.jit.ignore + def no_weight_decay(self): + return {''} + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + for l in self.layers: + l.set_grad_checkpointing(enable=enable) + + @torch.jit.ignore + def get_classifier(self): + return self.classifier.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.classifier.reset(num_classes, global_pool=global_pool) + + def forward_features(self, x): + x = self.stem(x) + x = self.layers(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + if name and 'head.fc' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.proj', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + "focalnet_tiny_srf": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_srf.pth'), + "focalnet_small_srf": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_srf.pth'), + "focalnet_base_srf": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth'), + "focalnet_tiny_lrf": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_lrf.pth'), + "focalnet_small_lrf": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth'), + "focalnet_base_lrf": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'), + "focalnet_large_fl3": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842), + "focalnet_large_fl4": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842), + "focalnet_xlarge_fl3": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842), + "focalnet_xlarge_fl4": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth', + input_size=(3, 384, 384), crop_pct=1.0, 
num_classes=21842), + "focalnet_huge_fl3": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth', + num_classes=0), + "focalnet_huge_fl4": _cfg( + url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth', + num_classes=0), +} + + +def checkpoint_filter_fn(state_dict, model: FocalNet): + if 'stem.proj.weight' in state_dict: + return + import re + out_dict = {} + if 'model' in state_dict: + state_dict = state_dict['model'] + dest_dict = model.state_dict() + for k, v in state_dict.items(): + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) + k = k.replace('patch_embed', 'stem') + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + if 'norm' in k and k not in dest_dict: + k = re.sub(r'norm([0-9])', r'norm\1_post', k) + k = k.replace('ln.', 'norm.') + k = k.replace('head', 'head.fc') + if dest_dict[k].shape != v.shape: + v = v.reshape(dest_dict[k].shape) + out_dict[k] = v + return out_dict + + +def _create_focalnet(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + FocalNet, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + return model + + +@register_model +def focalnet_tiny_srf(pretrained=False, **kwargs): + model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, **kwargs) + return _create_focalnet('focalnet_tiny_srf', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_small_srf(pretrained=False, **kwargs): + model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, **kwargs) + return _create_focalnet('focalnet_small_srf', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_base_srf(pretrained=False, **kwargs): + model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, **kwargs) + return _create_focalnet('focalnet_base_srf', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_tiny_lrf(pretrained=False, **kwargs): + model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs) + return _create_focalnet('focalnet_tiny_lrf', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_small_lrf(pretrained=False, **kwargs): + model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs) + return _create_focalnet('focalnet_small_lrf', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_base_lrf(pretrained=False, **kwargs): + model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs) + return _create_focalnet('focalnet_base_lrf', pretrained=pretrained, **model_kwargs) + + +# FocalNet large+ models +@register_model +def focalnet_large_fl3(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4, + use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs) + return _create_focalnet('focalnet_large_fl3', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_large_fl4(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4], + use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs) + 
return _create_focalnet('focalnet_large_fl4', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_xlarge_fl3(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4, + use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs) + return _create_focalnet('focalnet_xlarge_fl3', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_xlarge_fl4(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4], + use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs) + return _create_focalnet('focalnet_xlarge_fl4', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_huge_fl3(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], focal_windows=[3] * 4, + use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs) + return _create_focalnet('focalnet_huge_fl3', pretrained=pretrained, **model_kwargs) + + +@register_model +def focalnet_huge_fl4(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4], + use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs) + return _create_focalnet('focalnet_huge_fl4', pretrained=pretrained, **model_kwargs) +
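
For context, a minimal usage sketch of the class_map pass-through added above for the iterable TFDS/WDS readers. The dataset name and data root are hypothetical placeholders; per load_class_map, class_map may be a dict or a path to a .txt/.pkl file, and its keys must match the raw target values the reader yields.

from timm.data import create_dataset

# Hypothetical remap of raw integer labels to new indices; a '.txt' path
# (one class name per line) or '.pkl' path accepted by load_class_map() also works.
class_map = {0: 1, 1: 0, 2: 2}

dataset = create_dataset(
    'tfds/some_dataset',   # hypothetical TFDS dataset; 'wds/...' behaves the same way
    root='/data/tfds',     # hypothetical data root
    split='train',
    is_training=True,
    batch_size=8,
    class_map=class_map,   # now forwarded to ReaderTfds / ReaderWds
)

img, target = next(iter(dataset))  # target remapped via the reader's class_to_idx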
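
Likewise, a quick sketch of exercising the newly registered FocalNet variants through the standard timm API. Model names are taken from the registrations above; the 224x224 input matches the default cfg.

import torch
import timm

model = timm.create_model('focalnet_tiny_srf', pretrained=False, num_classes=10)
x = torch.randn(1, 3, 224, 224)
logits = model(x)  # shape (1, 10)

# Feature extraction path noted in the file header (features_only=True);
# returns a list of NCHW feature maps at strides 4, 8, 16, 32 per feature_info.
backbone = timm.create_model('focalnet_tiny_srf', pretrained=False, features_only=True)
feature_maps = backbone(x)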