Compare commits


9 Commits

@@ -23,11 +23,14 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 ## What's New
-### Jan 6, 2022
-* Version 0.5.2 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
-* Tried training a few small / mobile optimized models, a few are good so far, more on the way...
+### Jan 14, 2022
+* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
+* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
+* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
   * `mnasnet_small` - 65.6 top-1
-  * `lcnet_100` - 72.1 top-1
+  * `mobilenetv2_050` - 65.9
+  * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
+  * `semnasnet_075` - 73
   * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
 * TinyNet models added by [rsomani95](https://github.com/rsomani95)
 * LCNet added via MobileNetV3 architecture
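Once 0.5.4 lands on pypi, the new models create like any other timm model; a minimal sketch (assumes `pip install timm==0.5.4`):

import torch
import timm

# one of the newly added ConvNeXt variants, pretrained weights from the FB release
model = timm.create_model('convnext_tiny', pretrained=True).eval()

# default cfg: 3x224x224 input, 1000 ImageNet classes
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])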

@@ -28,12 +28,12 @@ NON_STD_FILTERS = [
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures
-if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
+if 'GITHUB_ACTIONS' in os.environ:
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*']
 else:
     EXCLUDE_FILTERS = []
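For reference, these exclusion entries are fnmatch-style wildcards tested against model names (a sketch of the assumed matching logic, not the actual test harness):

import fnmatch

EXCLUDE_FILTERS = ['*efficientnet_l2*', 'vit_huge*', 'vit_gi*']

def excluded(model_name):
    # a model is skipped if any wildcard pattern matches its name
    return any(fnmatch.fnmatch(model_name, f) for f in EXCLUDE_FILTERS)

print(excluded('vit_gigantic_patch14_224'))  # True, caught by 'vit_gi*'
print(excluded('vit_base_patch16_224'))      # False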
@@ -255,7 +255,7 @@ if 'GITHUB_ACTIONS' not in os.environ:

 EXCLUDE_JIT_FILTERS = [
     '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
     'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
-    'vit_large_*', 'vit_huge_*',
+    'vit_large_*', 'vit_huge_*', 'vit_gi*',
 ]
@@ -334,7 +334,7 @@ def _create_fx_model(model, train=False):
     return fx_model


-EXCLUDE_FX_FILTERS = []
+EXCLUDE_FX_FILTERS = ['vit_gi*']
 # not enough memory to run fx on more models than other tests
 if 'GITHUB_ACTIONS' in os.environ:
     EXCLUDE_FX_FILTERS += [

@@ -15,7 +15,7 @@ Papers:
     RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
     AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import random
 import math

@@ -1,6 +1,6 @@
 """ Quick n Simple Image Folder, Tarfile based DataSet

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import torch.utils.data as data
 import os

@@ -1,3 +1,7 @@
+""" Dataset Factory
+
+Hacked together by / Copyright 2021, Ross Wightman
+"""
 import os

 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder

@@ -3,7 +3,7 @@
 Prefetcher and Fast Collate inspired by NVIDIA APEX example at
 https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

-Hacked together by / Copyright 2021 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import random
 from functools import partial

@@ -8,7 +8,7 @@ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Fea
 Code Reference:
     CutMix: https://github.com/clovaai/CutMix-PyTorch

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import numpy as np
 import torch

@@ -3,7 +3,7 @@
 Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
 Copyright Zhun Zhong & Liang Zheng

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import random
 import math

@@ -1,7 +1,7 @@
 """ Transforms Factory
 Factory methods for building image transforms for use with TIMM (PyTorch Image Models)

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import math

@@ -5,6 +5,7 @@ from .cait import *
 from .coat import *
 from .convit import *
 from .convmixer import *
+from .convnext import *
 from .crossvit import *
 from .cspnet import *
 from .densenet import *

@@ -4,6 +4,7 @@ Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239
 Original code and weights from https://github.com/facebookresearch/deit, copyright below

+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
 """
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.

@@ -9,6 +9,8 @@
 Paper link: https://arxiv.org/abs/2103.10697
 Original code: https://github.com/facebookresearch/convit, original copyright below

+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
+
 """
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.

@@ -0,0 +1,368 @@
""" ConvNeXt
Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import named_apply, build_model_with_cfg
from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp
from .registry import register_model
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = dict(
    convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
    convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
    convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
    convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),

    convnext_tiny_hnf=_cfg(url=''),

    convnext_base_in22k=_cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
    convnext_large_in22k=_cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
    convnext_xlarge_in22k=_cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
)
def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
    # if torch.jit.is_tracing():
    #     return True
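    # NOTE the memory_format arg used below is what trips up torchscript (assumed reason for the branch)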
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)
@register_notrace_module
class LayerNorm2d(nn.LayerNorm):
    r""" LayerNorm for channels_first tensors with 2d spatial dimensions (i.e. N, C, H, W).
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__(normalized_shape, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
            return F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
        else:
            # manual norm over dim 1; biased variance (unbiased=False) matches F.layer_norm
            s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
            x = (x - u) * torch.rsqrt(s + self.eps)
            x = x * self.weight[:, None, None] + self.bias[:, None, None]
            return x
class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block
    There are two equivalent implementations:
      (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
      (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back

    Unlike the official impl, this one allows choice of 1 or 2. The 1x1 conv variant can be faster with an
    appropriate choice of LayerNorm impl; however, as model size increases the tradeoffs appear to change and
    nn.Linear is a better choice. This was observed with PyTorch 1.10 on a 3090 GPU; it could change over
    time & w/ different HW.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None):
        super().__init__()
        if not norm_layer:
            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
        mlp_layer = ConvMlp if conv_mlp else Mlp
        self.use_conv_mlp = conv_mlp
        self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = norm_layer(dim)
        self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU)
        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    def forward(self, x):
        shortcut = x
        x = self.conv_dw(x)
        if self.use_conv_mlp:
            x = self.norm(x)
            x = self.mlp(x)
        else:
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            x = self.mlp(x)
            x = x.permute(0, 3, 1, 2)
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        x = self.drop_path(x) + shortcut
        return x
class ConvNeXtStage(nn.Module):

    def __init__(
            self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=True,
            norm_layer=None, cl_norm_layer=None, cross_stage=False):
        super().__init__()

        if in_chs != out_chs or stride > 1:
            self.downsample = nn.Sequential(
                norm_layer(in_chs),
                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride),
            )
        else:
            self.downsample = nn.Identity()

        dp_rates = dp_rates or [0.] * depth
        self.blocks = nn.Sequential(*[ConvNeXtBlock(
            dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
            norm_layer=norm_layer if conv_mlp else cl_norm_layer)
            for j in range(depth)]
        )

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x
class ConvNeXt(nn.Module):
    r""" ConvNeXt
    A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_rate (float): Head dropout rate
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(
            self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,
            depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=True,
            head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0.,
    ):
        super().__init__()
        assert output_stride == 32
        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)
            cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
        else:
            assert conv_mlp, \
                'If a norm_layer is specified, conv MLP must be used so all norms expect rank-4, channels-first input'
            cl_norm_layer = norm_layer

        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.feature_info = []

        # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
            norm_layer(dims[0])
        )

        self.stages = nn.Sequential()
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
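        # e.g. depths=(3, 3, 9, 3) w/ drop_path_rate=0.1 -> 18 linearly spaced rates, split back into
        # per-stage lists of length 3/3/9/3 (illustrative values, not from the original source)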
        curr_stride = patch_size
        prev_chs = dims[0]
        stages = []
        # 4 feature resolution stages, each consisting of multiple residual blocks
        for i in range(4):
            stride = 2 if i > 0 else 1
            # FIXME support dilation / output_stride
            curr_stride *= stride
            out_chs = dims[i]
            stages.append(ConvNeXtStage(
                prev_chs, out_chs, stride=stride,
                depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
                norm_layer=norm_layer, cl_norm_layer=cl_norm_layer)
            )
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        self.num_features = prev_chs
        if head_norm_first:
            # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
            self.norm_pre = norm_layer(self.num_features)  # final norm layer, before pooling
            self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
        else:
            # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
            self.norm_pre = nn.Identity()
            self.head = nn.Sequential(OrderedDict([
                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
                ('norm', norm_layer(self.num_features)),
                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
                ('drop', nn.Dropout(self.drop_rate)),
                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
            ]))

        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes=0, global_pool='avg'):
        if isinstance(self.head, ClassifierHead):
            # norm -> global pool -> fc
            self.head = ClassifierHead(
                self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
        else:
            # pool -> norm -> fc
            self.head = nn.Sequential(OrderedDict([
                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
                ('norm', self.head.norm),
                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
                ('drop', nn.Dropout(self.drop_rate)),
                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
            ]))

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
def _init_weights(module, name=None, head_init_scale=1.0):
    if isinstance(module, nn.Conv2d):
        trunc_normal_(module.weight, std=.02)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        nn.init.constant_(module.bias, 0)
        if name and 'head.' in name:
            module.weight.data.mul_(head_init_scale)
            module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
    """ Remap FB checkpoints -> timm """
    if 'model' in state_dict:
        state_dict = state_dict['model']
    out_dict = {}
    import re
    for k, v in state_dict.items():
        k = k.replace('downsample_layers.0.', 'stem.')
        k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
        k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
        k = k.replace('dwconv', 'conv_dw')
        k = k.replace('pwconv', 'mlp.fc')
        k = k.replace('head.', 'head.fc.')
        if k.startswith('norm.'):
            k = k.replace('norm', 'head.norm')
        if v.ndim == 2 and 'head' not in k:
            model_shape = model.state_dict()[k].shape
            v = v.reshape(model_shape)
        out_dict[k] = v
    return out_dict
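
# Illustrative key remappings produced by checkpoint_filter_fn above
# (hypothetical checkpoint keys, for reading convenience only):
#   'downsample_layers.0.0.weight' -> 'stem.0.weight'
#   'downsample_layers.1.0.weight' -> 'stages.1.downsample.0.weight'
#   'stages.0.1.dwconv.weight'     -> 'stages.0.blocks.1.conv_dw.weight'
#   'norm.weight'                  -> 'head.norm.weight'
#   'head.weight'                  -> 'head.fc.weight'
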
def _create_convnext(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        ConvNeXt, variant, pretrained,
        default_cfg=default_cfgs[variant],
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs)
    return model
@register_model
def convnext_tiny(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
    model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_tiny_hnf(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs)
    model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_small(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_base(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_large(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_base_in22k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_large_in22k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args)
    return model


@register_model
def convnext_xlarge_in22k(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], conv_mlp=False, **kwargs)
    model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
    return model
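Since _create_convnext wires up feature_cfg and each stage registers feature_info, the new models should also work with timm's features_only extraction; a minimal sketch (shapes worked out from the configs above, not measured):

import torch
import timm

model = timm.create_model('convnext_base', features_only=True)
for f in model(torch.randn(1, 3, 224, 224)):
    print(f.shape)
# convnext_base dims are (128, 256, 512, 1024) at strides 4/8/16/32, so expect
# (1, 128, 56, 56), (1, 256, 28, 28), (1, 512, 14, 14), (1, 1024, 7, 7)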

@@ -12,6 +12,8 @@ Paper link: https://arxiv.org/abs/2103.14899
 Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
 NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408

+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
+
 """
 # Copyright IBM All Rights Reserved.

@@ -409,7 +409,7 @@ class CspNet(nn.Module):
 def _create_cspnet(variant, pretrained=False, **kwargs):
     cfg_variant = variant.split('_')[0]
     # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
-    out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4))
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4))
     return build_model_with_cfg(
         CspNet, variant, pretrained,
         default_cfg=default_cfgs[variant],
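The get -> pop change matters because build_model_with_cfg forwards any remaining kwargs to the model constructor; a minimal illustration of the difference:

kwargs = dict(out_indices=(0, 1, 2, 3, 4, 5), num_classes=10)

# dict.get reads the value but leaves the key behind, so 'out_indices' would
# also be forwarded to the model constructor via **kwargs
out_indices = kwargs.get('out_indices')
assert 'out_indices' in kwargs

# dict.pop removes it, so only the remaining kwargs travel onward
out_indices = kwargs.pop('out_indices')
assert 'out_indices' not in kwargs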

@@ -33,7 +33,7 @@ The majority of the above models (EfficientNet*, MixNet, MnasNet) and original w
 by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing
 the models and weights open source!

-Hacked together by / Copyright 2021 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
 from typing import List
@@ -73,7 +73,8 @@ default_cfgs = {
     'mnasnet_140': _cfg(url=''),

     'semnasnet_050': _cfg(url=''),
-    'semnasnet_075': _cfg(url=''),
+    'semnasnet_075': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/semnasnet_075-18710866.pth'),
     'semnasnet_100': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
     'semnasnet_140': _cfg(url=''),
@@ -83,7 +84,9 @@ default_cfgs = {
     'mobilenetv2_035': _cfg(
         url=''),
     'mobilenetv2_050': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth',
+        interpolation='bicubic',
+    ),
     'mobilenetv2_075': _cfg(
         url=''),
     'mobilenetv2_100': _cfg(
@@ -718,7 +721,7 @@ def _gen_mobilenet_v2(
     round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
-        num_features=1280 if fix_stem_head else round_chs_fn(1280),
+        num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
        stem_size=32,
        fix_stem=fix_stem_head,
        round_chs_fn=round_chs_fn,
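The max(1280, ...) guard stops the MobileNetV2 head width from shrinking when channel_multiplier < 1.0 while still letting it grow for larger multipliers; a worked sketch (make_divisible below is a simplified stand-in for timm's helper, an assumption):

def make_divisible(v, divisor=8):
    # simplified stand-in for timm's make_divisible, rounds to nearest multiple of divisor
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # don't round down by more than 10%
        new_v += divisor
    return new_v

def round_chs(v, multiplier=1.0):
    return make_divisible(v * multiplier)

print(max(1280, round_chs(1280, 0.5)))  # 1280 -> head no longer shrinks to 640
print(max(1280, round_chs(1280, 1.4)))  # 1792 -> head still scales up as before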

@@ -1,6 +1,6 @@
 """ EfficientNet, MobileNetV3, etc Blocks

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import torch

@@ -3,7 +3,7 @@
 Assembles EfficientNet and related network feature blocks from string definitions.
 Handles stride, dilation calculations, and selects feature extraction points.

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 import logging

@@ -21,7 +21,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
-from .mlp import Mlp, GluMlp, GatedMlp
+from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct

@@ -14,7 +14,7 @@ Adapted from official impl at https://github.com/facebookresearch/LeViT, origina
 This version combines both conv/linear models and fixes torchscript compatibility.

-Modifications by/coyright Copyright 2021 Ross Wightman
+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
 """
 # Copyright (c) 2015-present, Facebook, Inc.

@@ -1,11 +1,10 @@
 """ MobileNet V3

 A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.

 Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244

-Hacked together by / Copyright 2021 Ross Wightman
+Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
 from typing import List
@@ -86,8 +85,14 @@ default_cfgs = {
         input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95),

     "lcnet_035": _cfg(),
-    "lcnet_050": _cfg(),
-    "lcnet_075": _cfg(),
+    "lcnet_050": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
+        interpolation='bicubic',
+    ),
+    "lcnet_075": _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
+        interpolation='bicubic',
+    ),
     "lcnet_100": _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
         interpolation='bicubic',
@@ -111,9 +116,10 @@ class MobileNetV3(nn.Module):
     * LCNet - https://arxiv.org/abs/2109.15099
     """

-    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
-                 pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
-                 round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
+    def __init__(
+            self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
+            head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
+            round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
         super(MobileNetV3, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
@@ -123,7 +129,8 @@ class MobileNetV3(nn.Module):
         self.drop_rate = drop_rate

         # Stem
-        stem_size = round_chs_fn(stem_size)
+        if not fix_stem:
+            stem_size = round_chs_fn(stem_size)
         self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size)
         self.act1 = act_layer(inplace=True)
@@ -189,8 +196,8 @@ class MobileNetV3Features(nn.Module):
     """

     def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
-                 stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
-                 act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
+                 stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
+                 se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
         super(MobileNetV3Features, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
@@ -198,7 +205,8 @@ class MobileNetV3Features(nn.Module):
         self.drop_rate = drop_rate

         # Stem
-        stem_size = round_chs_fn(stem_size)
+        if not fix_stem:
+            stem_size = round_chs_fn(stem_size)
         self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size)
         self.act1 = act_layer(inplace=True)
@@ -382,6 +390,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
         block_args=decode_arch_def(arch_def),
         num_features=num_features,
         stem_size=16,
+        fix_stem=channel_multiplier < 0.75,
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         act_layer=act_layer,
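fix_stem is what keeps the 16-channel stem intact for the new sub-0.75 multiplier variants; without it the stem would round down and no longer match the architecture the released weights were trained with (assumed rationale). Reusing the round_chs sketch from the MobileNetV2 note above:

print(round_chs(16, 0.35))  # 8  -> stem would shrink, hence fix_stem=True below 0.75
print(round_chs(16, 0.75))  # 16 -> multipliers >= 0.75 are unaffected by rounding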

@@ -4,7 +4,8 @@ This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-C
 additional dropout and dynamic global avg/max pool.

 ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
-Copyright 2020 Ross Wightman
+
+Copyright 2019, Ross Wightman
 """
 import math
 from functools import partial

@@ -4,6 +4,7 @@ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shi
 Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below

+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
 """
 # --------------------------------------------------------
 # Swin Transformer

@@ -180,7 +180,7 @@ def _filter_fn(state_dict):
 def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
     cfg = variant.split('_')[0]
     # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
-    out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5))
     model = build_model_with_cfg(
         VGG, variant, pretrained,
         default_cfg=default_cfgs[variant],

@@ -4,6 +4,7 @@ Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.1
 From original at https://github.com/danczs/Visformer

+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
 """
 from copy import deepcopy

@@ -20,7 +20,7 @@ for some einops/einsum fun
 * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
 * Bert reference code checks against Huggingface Transformers and Tensorflow Bert

-Hacked together by / Copyright 2021 Ross Wightman
+Hacked together by / Copyright 2020, Ross Wightman
 """
 import math
 import logging
@@ -105,6 +105,10 @@ default_cfgs = {
         'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0),

+    'vit_huge_patch14_224': _cfg(url=''),
+    'vit_giant_patch14_224': _cfg(url=''),
+    'vit_gigantic_patch14_224': _cfg(url=''),
+
     # patch models, imagenet21k (weights from official Google JAX impl)
     'vit_tiny_patch16_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
@@ -715,6 +719,33 @@ def vit_base_patch32_sam_224(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_huge_patch14_224(pretrained=False, **kwargs):
+    """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
+    """
+    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
+    model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_giant_patch14_224(pretrained=False, **kwargs):
+    """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
+    """
+    model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
+    model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_gigantic_patch14_224(pretrained=False, **kwargs):
+    """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
+    """
+    model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
+    model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16).

@@ -11,7 +11,7 @@ A PyTorch implement of the Hybrid Vision Transformers as described in:
 NOTE These hybrid model definitions depend on code in vision_transformer.py.
 They were moved here to keep file sizes sane.

-Hacked together by / Copyright 2021 Ross Wightman
+Hacked together by / Copyright 2020, Ross Wightman
 """
 from copy import deepcopy
 from functools import partial

@@ -1,10 +1,12 @@
 """ Cross-Covariance Image Transformer (XCiT) in PyTorch
-Same as the official implementation, with some minor adaptations.
-    - https://github.com/facebookresearch/xcit/blob/master/xcit.py

 Paper:
     - https://arxiv.org/abs/2106.09681

+Same as the official implementation, with some minor adaptations, original copyright below
+    - https://github.com/facebookresearch/xcit/blob/master/xcit.py
+
+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
 """
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.

@@ -1 +1 @@
-__version__ = '0.5.2'
+__version__ = '0.5.4'
