Compare commits

...

9 Commits

@ -23,11 +23,14 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### Jan 6, 2022
* Version 0.5.2 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
* Tried training a few small / mobile optimized models, a few are good so far, more on the way...
### Jan 14, 2022
* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
* `mnasnet_small` - 65.6 top-1
* `lcnet_100` - 72.1 top-1
* `mobilenetv2_050` - 65.9
* `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
* `semnasnet_075` - 73
* `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
* TinyNet models added by [rsomani95](https://github.com/rsomani95)
* LCNet added via MobileNetV3 architecture

@ -28,12 +28,12 @@ NON_STD_FILTERS = [
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
if 'GITHUB_ACTIONS' in os.environ:
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*']
else:
EXCLUDE_FILTERS = []
@ -255,7 +255,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'vit_large_*', 'vit_huge_*',
'vit_large_*', 'vit_huge_*', 'vit_gi*',
]
@ -334,7 +334,7 @@ def _create_fx_model(model, train=False):
return fx_model
EXCLUDE_FX_FILTERS = []
EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [

@ -15,7 +15,7 @@ Papers:
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import random
import math

@ -1,6 +1,6 @@
""" Quick n Simple Image Folder, Tarfile based DataSet
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import torch.utils.data as data
import os

@ -1,3 +1,7 @@
""" Dataset Factory
Hacked together by / Copyright 2021, Ross Wightman
"""
import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder

@ -3,7 +3,7 @@
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
Hacked together by / Copyright 2021 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import random
from functools import partial

@ -8,7 +8,7 @@ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Fea
Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import numpy as np
import torch

@ -3,7 +3,7 @@
Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
Copyright Zhun Zhong & Liang Zheng
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import random
import math

@ -1,7 +1,7 @@
""" Transforms Factory
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import math

@ -5,6 +5,7 @@ from .cait import *
from .coat import *
from .convit import *
from .convmixer import *
from .convnext import *
from .crossvit import *
from .cspnet import *
from .densenet import *

@ -4,6 +4,7 @@ Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239
Original code and weights from https://github.com/facebookresearch/deit, copyright below
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.

@ -9,6 +9,8 @@
Paper link: https://arxiv.org/abs/2103.10697
Original code: https://github.com/facebookresearch/convit, original copyright below
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.

@ -0,0 +1,368 @@
""" ConvNeXt
Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import named_apply, build_model_with_cfg
from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp
from .registry import register_model
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_tiny_hnf=_cfg(url=''),
convnext_base_in22k=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
convnext_large_in22k=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
convnext_xlarge_in22k=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
)
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@register_notrace_module
class LayerNorm2d(nn.LayerNorm):
r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__(normalized_shape, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
else:
s, u = torch.var_mean(x, dim=1, keepdim=True)
x = (x - u) * torch.rsqrt(s + self.eps)
x = x * self.weight[:, None, None] + self.bias[:, None, None]
return x
class ConvNeXtBlock(nn.Module):
""" ConvNeXt Block
There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None):
super().__init__()
if not norm_layer:
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
mlp_layer = ConvMlp if conv_mlp else Mlp
self.use_conv_mlp = conv_mlp
self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = norm_layer(dim)
self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv_dw(x)
if self.use_conv_mlp:
x = self.norm(x)
x = self.mlp(x)
else:
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = self.mlp(x)
x = x.permute(0, 3, 1, 2)
if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut
return x
class ConvNeXtStage(nn.Module):
def __init__(
self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=True,
norm_layer=None, cl_norm_layer=None, cross_stage=False):
super().__init__()
if in_chs != out_chs or stride > 1:
self.downsample = nn.Sequential(
norm_layer(in_chs),
nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride),
)
else:
self.downsample = nn.Identity()
dp_rates = dp_rates or [0.] * depth
self.blocks = nn.Sequential(*[ConvNeXtBlock(
dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
norm_layer=norm_layer if conv_mlp else cl_norm_layer)
for j in range(depth)]
)
def forward(self, x):
x = self.downsample(x)
x = self.blocks(x)
return x
class ConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_rate (float): Head dropout rate
drop_path_rate (float): Stochastic depth rate. Default: 0.
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=True,
head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0.,
):
super().__init__()
assert output_stride == 32
if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6)
cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
cl_norm_layer = norm_layer
self.num_classes = num_classes
self.drop_rate = drop_rate
self.feature_info = []
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
norm_layer(dims[0])
)
self.stages = nn.Sequential()
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
curr_stride = patch_size
prev_chs = dims[0]
stages = []
# 4 feature resolution stages, each consisting of multiple residual blocks
for i in range(4):
stride = 2 if i > 0 else 1
# FIXME support dilation / output_stride
curr_stride *= stride
out_chs = dims[i]
stages.append(ConvNeXtStage(
prev_chs, out_chs, stride=stride,
depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
norm_layer=norm_layer, cl_norm_layer=cl_norm_layer)
)
prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = prev_chs
if head_norm_first:
# norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
self.norm_pre = norm_layer(self.num_features) # final norm layer, before pooling
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
else:
# pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool='avg'):
if isinstance(self.head, ClassifierHead):
# norm -> global pool -> fc
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
else:
# pool -> norm -> fc
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', self.head.norm),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0)
if name and 'head.' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
""" Remap FB checkpoints -> timm """
if 'model' in state_dict:
state_dict = state_dict['model']
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('dwconv', 'conv_dw')
k = k.replace('pwconv', 'mlp.fc')
k = k.replace('head.', 'head.fc.')
if k.startswith('norm.'):
k = k.replace('norm', 'head.norm')
if v.ndim == 2 and 'head' not in k:
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
def _create_convnext(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
ConvNeXt, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
**kwargs)
return model
@register_model
def convnext_tiny(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_tiny_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_base(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_large(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_base_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_large_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_xlarge_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], conv_mlp=False, **kwargs)
model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
return model

@ -12,6 +12,8 @@ Paper link: https://arxiv.org/abs/2103.14899
Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright IBM All Rights Reserved.

@ -409,7 +409,7 @@ class CspNet(nn.Module):
def _create_cspnet(variant, pretrained=False, **kwargs):
cfg_variant = variant.split('_')[0]
# NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4))
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4))
return build_model_with_cfg(
CspNet, variant, pretrained,
default_cfg=default_cfgs[variant],

@ -33,7 +33,7 @@ The majority of the above models (EfficientNet*, MixNet, MnasNet) and original w
by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing
the models and weights open source!
Hacked together by / Copyright 2021 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import List
@ -73,7 +73,8 @@ default_cfgs = {
'mnasnet_140': _cfg(url=''),
'semnasnet_050': _cfg(url=''),
'semnasnet_075': _cfg(url=''),
'semnasnet_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/semnasnet_075-18710866.pth'),
'semnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
'semnasnet_140': _cfg(url=''),
@ -83,7 +84,9 @@ default_cfgs = {
'mobilenetv2_035': _cfg(
url=''),
'mobilenetv2_050': _cfg(
url=''),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth',
interpolation='bicubic',
),
'mobilenetv2_075': _cfg(
url=''),
'mobilenetv2_100': _cfg(
@ -718,7 +721,7 @@ def _gen_mobilenet_v2(
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
num_features=1280 if fix_stem_head else round_chs_fn(1280),
num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
stem_size=32,
fix_stem=fix_stem_head,
round_chs_fn=round_chs_fn,

@ -1,6 +1,6 @@
""" EfficientNet, MobileNetV3, etc Blocks
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import torch

@ -3,7 +3,7 @@
Assembles EfficieNet and related network feature blocks from string definitions.
Handles stride, dilation calculations, and selects feature extraction points.
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
import logging

@ -21,7 +21,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct

@ -14,7 +14,7 @@ Adapted from official impl at https://github.com/facebookresearch/LeViT, origina
This version combines both conv/linear models and fixes torchscript compatibility.
Modifications by/coyright Copyright 2021 Ross Wightman
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.

@ -1,11 +1,10 @@
""" MobileNet V3
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by / Copyright 2021 Ross Wightman
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import List
@ -86,8 +85,14 @@ default_cfgs = {
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95),
"lcnet_035": _cfg(),
"lcnet_050": _cfg(),
"lcnet_075": _cfg(),
"lcnet_050": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
interpolation='bicubic',
),
"lcnet_075": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
interpolation='bicubic',
),
"lcnet_100": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
interpolation='bicubic',
@ -111,8 +116,9 @@ class MobileNetV3(nn.Module):
* LCNet - https://arxiv.org/abs/2109.15099
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
def __init__(
self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
@ -123,6 +129,7 @@ class MobileNetV3(nn.Module):
self.drop_rate = drop_rate
# Stem
if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size)
@ -189,8 +196,8 @@ class MobileNetV3Features(nn.Module):
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@ -198,6 +205,7 @@ class MobileNetV3Features(nn.Module):
self.drop_rate = drop_rate
# Stem
if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size)
@ -382,6 +390,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
fix_stem=channel_multiplier < 0.75,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,

@ -4,7 +4,8 @@ This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-C
additional dropout and dynamic global avg/max pool.
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
Copyright 2020 Ross Wightman
Copyright 2019, Ross Wightman
"""
import math
from functools import partial

@ -4,6 +4,7 @@ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shi
Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# --------------------------------------------------------
# Swin Transformer

@ -180,7 +180,7 @@ def _filter_fn(state_dict):
def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
cfg = variant.split('_')[0]
# NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5))
model = build_model_with_cfg(
VGG, variant, pretrained,
default_cfg=default_cfgs[variant],

@ -4,6 +4,7 @@ Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.1
From original at https://github.com/danczs/Visformer
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
from copy import deepcopy

@ -20,7 +20,7 @@ for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2021 Ross Wightman
Hacked together by / Copyright 2020, Ross Wightman
"""
import math
import logging
@ -105,6 +105,10 @@ default_cfgs = {
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_huge_patch14_224': _cfg(url=''),
'vit_giant_patch14_224': _cfg(url=''),
'vit_gigantic_patch14_224': _cfg(url=''),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
@ -715,6 +719,33 @@ def vit_base_patch32_sam_224(pretrained=False, **kwargs):
return model
@register_model
def vit_huge_patch14_224(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_giant_patch14_224(pretrained=False, **kwargs):
""" ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
"""
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
""" ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
"""
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16).

@ -11,7 +11,7 @@ A PyTorch implement of the Hybrid Vision Transformers as described in:
NOTE These hybrid model definitions depend on code in vision_transformer.py.
They were moved here to keep file sizes sane.
Hacked together by / Copyright 2021 Ross Wightman
Hacked together by / Copyright 2020, Ross Wightman
"""
from copy import deepcopy
from functools import partial

@ -1,10 +1,12 @@
""" Cross-Covariance Image Transformer (XCiT) in PyTorch
Same as the official implementation, with some minor adaptations.
- https://github.com/facebookresearch/xcit/blob/master/xcit.py
Paper:
- https://arxiv.org/abs/2106.09681
Same as the official implementation, with some minor adaptations, original copyright below
- https://github.com/facebookresearch/xcit/blob/master/xcit.py
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.

@ -1 +1 @@
__version__ = '0.5.2'
__version__ = '0.5.4'

Loading…
Cancel
Save