Merge remote-tracking branch 'origin/master' into bits_and_tpu

Ross Wightman 3 years ago
commit 3b6ba76126

@ -19,10 +19,22 @@ In addition to the sponsors at the link above, I've received hardware and/or clo
* Nvidia (
* TFRC (
I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of hardware, infrastructure, and electricty costs.
I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs.
## What's New
### Oct 19, 2021
* ResNet strikes back ( weights added, plus any extra training components used. Model weights and some more details here (
* BCE loss and Repeated Augmentation support for RSB paper
* 4 series of ResNet based attention model experiments being added (implemented across These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottleneck, Lambda. Details here (
* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
* Halo (
* Bottleneck Transformer (
* LambdaNetworks (
* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper ( in any way other than block architecture, details of official models are not available. See more here (
* ConvMixer (, CrossVit (, and BeiT ( architectures + weights added
* freeze/unfreeze helpers by [Alexander Soare](
### Aug 18, 2021
* Optimizer bonanza!
* Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](

@ -38,6 +38,20 @@ try:
except AttributeError:
from deepspeed.profiling.flops_profiler import get_model_profile
has_deepspeed_profiling = True
except ImportError as e:
has_deepspeed_profiling = False
from fvcore.nn import FlopCountAnalysis, flop_count_str
has_fvcore_profiling = True
except ImportError as e:
FlopCountAnalysis = None
has_fvcore_profiling = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -67,6 +81,8 @@ parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
help='Run inference at train size, not test-input-size if it exists.')
parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
@ -81,6 +97,7 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
# train optimizer parameters
# Optimizer selection; the help text now closes the parenthesis around the default value.
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
@ -139,10 +156,33 @@ def resolve_precision(precision: str):
return use_amp, model_dtype, data_dtype
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
    """Profile `model` with the DeepSpeed flops profiler and return its MAC count.

    Args:
        model: network to profile.
        input_size: per-sample input shape; the batch dimension is prepended here.
        batch_size: number of samples in the profiled input.
        detailed: when True, the profiler prints its per-module report.

    Returns:
        The first value returned by ``get_model_profile``, as a raw number
        (``as_string=False``).
    """
    macs, _ = get_model_profile(
        model=model,  # the network under test; without it the profiler has nothing to measure
        input_res=(batch_size,) + input_size,  # input shape or input to the input_constructor
        input_constructor=None,  # if specified, a constructor taking input_res is used as input to the model
        print_profile=detailed,  # prints the model graph with the measured profile attached to each module
        detailed=detailed,  # print the detailed profile
        warm_up=10,  # the number of warm-ups before measuring the time of each module
        as_string=False,  # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
        output_file=None,  # path to the output file. If None, the profiler prints to stdout.
        ignore_modules=None)  # the list of modules to ignore in the profiling
    return macs
def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
    """Count operations of `model` with fvcore's FlopCountAnalysis.

    Args:
        model: network to analyse; must own at least one parameter, which
            supplies the device and dtype of the dummy input.
        input_size: per-sample input shape; the batch dimension is prepended here.
        batch_size: number of samples in the dummy input.
        detailed: when True, print the per-module breakdown string.

    Returns:
        The analysis total, so callers can scale it (e.g. ``macs / 1e9``).
    """
    # Fetch one parameter only once; it tells us where and in what precision the model lives.
    ref_param = next(model.parameters())
    example_input = torch.ones(
        (batch_size,) + input_size, device=ref_param.device, dtype=ref_param.dtype)
    fca = FlopCountAnalysis(model, example_input)
    if detailed:
        # The breakdown string was previously built and then dropped.
        print(flop_count_str(fca))
    return fca.total()
class BenchmarkRunner:
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
num_warm_iter=10, num_bench_iter=50, **kwargs):
num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
self.model_name = model_name
self.detail = detail
self.device = device
@ -166,7 +206,7 @@ class BenchmarkRunner:
if torchscript:
self.model = torch.jit.script(self.model)
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=True)
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.input_size = data_config['input_size']
self.batch_size = kwargs.pop('batch_size', 256)
@ -234,6 +274,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
param_count=round(self.param_count / 1e6, 2),
if has_deepspeed_profiling:
macs = profile_deepspeed(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2)
elif has_fvcore_profiling:
macs = profile_fvcore(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2)
f"Inference benchmark of {self.model_name} done. "
f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")
@ -361,6 +408,44 @@ class TrainBenchmarkRunner(BenchmarkRunner):
return results
class ProfileRunner(BenchmarkRunner):
    """Benchmark runner that reports GMACs and parameter count for a model.

    The profiler backend is either given explicitly ('deepspeed' or 'fvcore')
    or picked from whichever is importable, preferring deepspeed.
    """

    def __init__(self, model_name, device='cuda', profiler='', **kwargs):
        super().__init__(model_name=model_name, device=device, **kwargs)
        if not profiler:
            # No explicit choice: use the first backend that imported successfully.
            if has_deepspeed_profiling:
                profiler = 'deepspeed'
            elif has_fvcore_profiling:
                profiler = 'fvcore'
        assert profiler, "One of deepspeed or fvcore needs to be installed for profiling to work."
        self.profiler = profiler

    def run(self):
        """Profile the model once and return a dict with `gmacs` and `param_count`."""
        _logger.info(
            f'Running profiler on {self.model_name} w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        macs = 0
        if self.profiler == 'deepspeed':
            macs = profile_deepspeed(self.model, self.input_size, batch_size=self.batch_size, detailed=True)
        elif self.profiler == 'fvcore':
            macs = profile_fvcore(self.model, self.input_size, batch_size=self.batch_size, detailed=True)

        results = dict(
            gmacs=round(macs / 1e9, 2),
            param_count=round(self.param_count / 1e6, 2),
        )

        _logger.info(
            f"Profile of {self.model_name} done. "
            f"{results['gmacs']:.2f} GMACs, {results['param_count']:.2f} M params.")

        return results
def decay_batch_exp(batch_size, factor=0.5, divisor=16):
out_batch_size = batch_size * factor
if out_batch_size > divisor:
@ -409,6 +494,16 @@ def benchmark(args):
elif args.bench == 'train':
bench_fns = TrainBenchmarkRunner,
prefixes = 'train',
elif args.bench.startswith('profile'):
# specific profiler used if included in bench mode string, otherwise default to deepspeed, fallback to fvcore
if 'deepspeed' in args.bench:
assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter"
bench_kwargs['profiler'] = 'deepspeed'
elif 'fvcore' in args.bench:
assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter"
bench_kwargs['profiler'] = 'fvcore'
bench_fns = ProfileRunner,
batch_size = 1
model_results = OrderedDict(model=model)
for prefix, bench_fn in zip(prefixes, bench_fns):
@ -456,16 +551,18 @@ def main():
except KeyboardInterrupt as e:
sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec'
sort_key = 'infer_samples_per_sec'
if 'train' in args.bench:
sort_key = 'train_samples_per_sec'
elif 'profile' in args.bench:
sort_key = 'infer_gmacs'
results = sorted(results, key=lambda x: x[sort_key], reverse=True)
if len(results):
write_results(results_file, results)
import json
json_str = json.dumps(results, indent=4)
results = benchmark(args)
json_str = json.dumps(results, indent=4)
def write_results(results_file, results):

@ -0,0 +1,57 @@
from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d
import timm
from timm.utils.model import freeze, unfreeze
def test_freeze_unfreeze():
    """Check freeze/unfreeze on a whole model, on selected submodules, and on a single BN layer."""
    model = timm.create_model('resnet18')

    # Freeze all
    freeze(model)
    # Check top level module
    assert not model.fc.weight.requires_grad
    # Check submodule
    assert not model.layer1[0].conv1.weight.requires_grad
    # Check BN
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)

    # Unfreeze all
    unfreeze(model)
    # Check top level module
    assert model.fc.weight.requires_grad
    # Check submodule
    assert model.layer1[0].conv1.weight.requires_grad
    # Check BN
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

    # Freeze some
    freeze(model, ['layer1', 'layer2.0'])
    # Check frozen
    assert not model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    assert not model.layer2[0].conv1.weight.requires_grad
    # Check not frozen
    assert model.layer3[0].conv1.weight.requires_grad
    assert isinstance(model.layer3[0].bn1, BatchNorm2d)
    assert model.layer2[1].conv1.weight.requires_grad

    # Unfreeze some
    unfreeze(model, ['layer1', 'layer2.0'])
    # Check not frozen
    assert model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    assert model.layer2[0].conv1.weight.requires_grad

    # Freeze/unfreeze BN
    # From root
    freeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    # From direct parent
    freeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

@ -216,7 +216,7 @@ class Mixup:
lam = self._mix_pair(x)
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
return x, target

@ -4,6 +4,7 @@ from .byobnet import *
from .cait import *
from .coat import *
from .convit import *
from .convmixer import *
from .crossvit import *
from .cspnet import *
from .densenet import *

@ -3,7 +3,7 @@
A flexible network w/ dataclass based config for stacking NN blocks including
self-attention (or similar) layers.
Currently used to implement experimential variants of:
Currently used to implement experimental variants of:
* Bottleneck Transformers
* Lambda ResNets
* HaloNets
@ -23,7 +23,7 @@ __all__ = []
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'crop_pct': 0.95, 'interpolation': 'bicubic',
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
'fixed_input_size': False, 'min_input_size': (3, 224, 224),
@ -34,35 +34,44 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# GPU-Efficient (ResNet) weights
'botnet26t_256': _cfg(
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'botnet50ts_256': _cfg(
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'eca_botnext26ts_256': _cfg(
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet26t': _cfg(
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'sehalonet33ts': _cfg(
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'halonet50ts': _cfg(
url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'eca_halonext26ts': _cfg(
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'lambda_resnet26t': _cfg(
min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
'lambda_resnet50ts': _cfg(
min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
'lambda_resnet26rpt_256': _cfg(
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
'haloregnetz_b': _cfg(
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
'trionet50ts_256': _cfg(
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
@ -113,7 +122,7 @@ model_cfgs = dict(
@ -141,7 +150,7 @@ model_cfgs = dict(
self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
self_attn_kwargs=dict(block_size=8, halo_size=2)
@ -231,6 +240,46 @@ model_cfgs = dict(
# experimental
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
block_kwargs=dict(bottle_in=True, linear_out=True),
self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
# experimental
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
self_attn_layer='bottleneck', self_attn_kwargs=dict()),
@ -246,7 +295,6 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
def botnet26t_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone.
NOTE: this isn't performing well, may remove
kwargs.setdefault('img_size', 256)
return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
@ -255,7 +303,6 @@ def botnet26t_256(pretrained=False, **kwargs):
def botnet50ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet50-T backbone, silu act.
NOTE: this isn't performing well, may remove
kwargs.setdefault('img_size', 256)
return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
@ -264,7 +311,6 @@ def botnet50ts_256(pretrained=False, **kwargs):
def eca_botnext26ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone, silu act.
NOTE: this isn't performing well, may remove
kwargs.setdefault('img_size', 256)
return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
@ -326,3 +372,17 @@ def lambda_resnet26rpt_256(pretrained=False, **kwargs):
kwargs.setdefault('img_size', 256)
return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
def haloregnetz_b(pretrained=False, **kwargs):
    """ Halo + RegNetZ
    """
    # The docstring is now terminated so the return below is real code again.
    return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)
def trionet50ts_256(pretrained=False, **kwargs):
    """ TrioNet
    """
    # The docstring is now terminated so the return below is real code again.
    return _create_byoanet('trionet50ts_256', 'trionet50ts', pretrained=pretrained, **kwargs)

@ -137,15 +137,15 @@ default_cfgs = {
# experimental models, likely to change or be removed
'regnetz_b': _cfgr(
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), first_conv='stem.conv'),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94),
'regnetz_c': _cfgr(
imean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv'),
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94),
'regnetz_d': _cfgr(
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
@ -1096,18 +1096,16 @@ class SelfAttnBlock(nn.Module):
def forward(self, x):
shortcut = self.shortcut(x)
shortcut = x
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.self_attn(x)
x = self.post_attn(x)
x = self.conv3_1x1(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
if self.shortcut is not None:
x = x + self.shortcut(shortcut)
return self.act(x)
_block_registry = dict(

@ -0,0 +1,101 @@
import torch.nn as nn
from timm.models.registry import register_model
from .helpers import build_model_with_cfg
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .96, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
'first_conv': 'stem.0',
default_cfgs = {
'convmixer_1536_20': _cfg(url=''),
'convmixer_768_32': _cfg(url=''),
'convmixer_1024_20_ks9_p14': _cfg(url='')
class Residual(nn.Module):
    """Wrap a callable `fn` so that the module computes ``fn(x) + x``."""

    def __init__(self, fn):
        # nn.Module must be initialised before a submodule can be assigned to it.
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x
class ConvMixer(nn.Module):
def __init__(self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, activation=nn.GELU, **kwargs):
self.num_classes = num_classes
self.num_features = dim
self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size),
self.blocks = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
nn.Conv2d(dim, dim, kernel_size=1),
) for i in range(depth)]
self.pooling = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.pooling(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _create_convmixer(variant, pretrained=False, **kwargs):
    """Build the ConvMixer `variant` via build_model_with_cfg, using its entry in default_cfgs."""
    variant_cfg = default_cfgs[variant]
    return build_model_with_cfg(
        ConvMixer,
        variant,
        pretrained,
        default_cfg=variant_cfg,
        **kwargs,
    )
@register_model
def convmixer_1536_20(pretrained=False, **kwargs):
    """ConvMixer with dim=1536, depth=20, kernel_size=9, patch_size=7.

    Extra keyword arguments are forwarded to the model constructor.
    """
    model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
    return _create_convmixer('convmixer_1536_20', pretrained, **model_args)
@register_model
def convmixer_768_32(pretrained=False, **kwargs):
    """ConvMixer with dim=768, depth=32, kernel_size=7, patch_size=7 and ReLU activation.

    Extra keyword arguments are forwarded to the model constructor.
    """
    model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, activation=nn.ReLU, **kwargs)
    return _create_convmixer('convmixer_768_32', pretrained, **model_args)
@register_model
def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs):
    """ConvMixer with dim=1024, depth=20, kernel_size=9, patch_size=14.

    Extra keyword arguments are forwarded to the model constructor.
    """
    model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs)
    return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args)

@ -20,7 +20,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers import to_2tuple
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
@ -61,15 +61,14 @@ class PosEmbedRel(nn.Module):
self.height, self.width = to_2tuple(feat_size)
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale)
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)
def forward(self, q):
B, num_heads, HW, _ = q.shape
B, HW, _ = q.shape
# relative logits in width dimension.
q = q.reshape(B * num_heads, self.height, self.width, -1)
q = q.reshape(B, self.height, self.width, -1)
rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
# relative logits in height dimension.
@ -77,35 +76,58 @@ class PosEmbedRel(nn.Module):
rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
rel_logits = rel_logits_h + rel_logits_w
rel_logits = rel_logits.reshape(B, num_heads, HW, HW)
rel_logits = rel_logits.reshape(B, HW, HW)
return rel_logits
class BottleneckAttn(nn.Module):
""" Bottleneck Attention
Paper: `Bottleneck Transformers for Visual Recognition` -
The internal dimensions of the attention module are controlled by the interaction of several arguments.
* the output dimension of the module is specified by dim_out, which falls back to input dim if not set
* the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
* the query and key (qk) dimensions are determined by
* num_heads * dim_head if dim_head is not None
* num_heads * (dim_out * qk_ratio // num_heads) if dim_head is None
* as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not used
dim (int): input dimension to the module
dim_out (int): output dimension of the module, same as dim if not set
stride (int): output stride of the module, avg pool used if stride == 2 (default: 1).
num_heads (int): parallel attention heads (default: 4)
dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool): add bias to q, k, and v projections
scale_pos_embed (bool): scale the position embedding as well as Q @ K
def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False):
def __init__(
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None,
qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False):
assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.num_heads = num_heads
self.dim_out = dim_out
self.dim_head = dim_out // num_heads
self.scale = self.dim_head ** -0.5
self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
self.dim_head_v = dim_out // self.num_heads
self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v
self.scale = self.dim_head_qk ** -0.5
self.scale_pos_embed = scale_pos_embed
self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias)
self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
# NOTE I'm only supporting relative pos embedding for now
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale)
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
@ -114,16 +136,23 @@ class BottleneckAttn(nn.Module):
assert H == self.pos_embed.height
assert W == self.pos_embed.width
x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W
x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
q, k, v = torch.split(x, self.num_heads, dim=1)
x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
# NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v
# So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.
q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
attn_logits = (q @ k.transpose(-1, -2)) * self.scale
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
if self.scale_pos_embed:
attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
attn = (q @ k) * self.scale + self.pos_embed(q)
attn = attn.softmax(dim=-1)
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
attn_out = self.pool(attn_out)
return attn_out
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
out = self.pool(out)
return out

@ -22,6 +22,7 @@ import torch
from torch import nn
import torch.nn.functional as F
from .helpers import make_divisible
from .weight_init import trunc_normal_
@ -73,9 +74,8 @@ class PosEmbedRel(nn.Module):
self.block_size = block_size
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
def forward(self, q):
B, BB, HW, _ = q.shape
@ -98,30 +98,63 @@ class HaloAttn(nn.Module):
Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
The internal dimensions of the attention module are controlled by the interaction of several arguments.
* the output dimension of the module is specified by dim_out, which falls back to input dim if not set
* the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
* the query and key (qk) dimensions are determined by
* num_heads * dim_head if dim_head is not None
* num_heads * (dim_out * qk_ratio // num_heads) if dim_head is None
* as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not used
dim (int): input dimension to the module
dim_out (int): output dimension of the module, same as dim if not set
feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
stride: output stride of the module, query downscaled if > 1 (default: 1).
num_heads: parallel attention heads (default: 8).
dim_head: dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
block_size (int): size of blocks. (default: 8)
halo_size (int): size of halo overlap. (default: 3)
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool) : add bias to q, k, and v projections
avg_down (bool): use average pool downsample instead of strided query blocks
scale_pos_embed (bool): scale the position embedding as well as Q @ K
def __init__(
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.stride = stride
assert stride in (1, 2)
self.num_heads = num_heads
self.dim_head = dim_head or dim // num_heads
self.dim_qk = num_heads * self.dim_head
self.dim_v = dim_out
self.block_size = block_size
self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
self.dim_head_v = dim_out // self.num_heads
self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v
self.scale = self.dim_head_qk ** -0.5
self.scale_pos_embed = scale_pos_embed
self.block_size = self.block_size_ds = block_size
self.halo_size = halo_size
self.win_size = block_size + halo_size * 2 # neighbourhood window size
self.scale = self.dim_head ** -0.5
self.block_stride = 1
use_avg_pool = False
if stride > 1:
use_avg_pool = avg_down or block_size % stride != 0
self.block_stride = 1 if use_avg_pool else stride
self.block_size_ds = self.block_size // self.block_stride
# FIXME not clear if this stride behaviour is what the paper intended
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
# data in unfolded block form. I haven't wrapped my head around how that'd look.
self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)
self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
self.pos_embed = PosEmbedRel(
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()
@ -139,41 +172,61 @@ class HaloAttn(nn.Module):
num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks
bs_stride = self.block_size // self.stride
q = self.q(x)
# unfold
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
q = q.reshape(
-1, self.dim_head_qk,
num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
# B, num_heads * dim_head * block_size ** 2, num_blocks
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
# B * num_heads, num_blocks, block_size ** 2, dim_head
kv = self.kv(x)
# generate overlapping windows for kv
# Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
# lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
# FIXME figure out how to switch impl between this and conv2d if XLA being used.
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1)
# NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity
# if self.stride_tricks:
# kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
# kv = kv.as_strided((
# B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
# stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
# else:
# kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
# kv = kv.reshape(
# B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
# B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
if self.scale_pos_embed:
attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
# B * num_heads, num_blocks, block_size ** 2, win_size ** 2
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
# fold
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
# B, dim_out, H // stride, W // stride
return attn_out
out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
out = out.permute(0, 3, 1, 4, 2).contiguous().view(
B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
# B, dim_out, H // block_stride, W // block_stride
out = self.pool(out)
return out
""" Three alternatives for overlapping windows.
`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()
if is_xla:
# This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is
# EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment.
WW = self.win_size ** 2
pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
elif self.stride_tricks:
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
kv = kv.as_strided((
B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
kv = kv.reshape(
B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)

@ -24,7 +24,7 @@ import torch
from torch import nn
import torch.nn.functional as F
from .helpers import to_2tuple
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
@ -44,28 +44,46 @@ class LambdaLayer(nn.Module):
NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
The internal dimensions of the lambda module are controlled via the interaction of several arguments.
* the output dimension of the module is specified by dim_out, which falls back to input dim if not set
* the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
* the query (q) and key (k) dimension are determined by
* dim_head = (dim_out * qk_ratio // num_heads) if dim_head is None
* q = num_heads * dim_head, k = dim_head
* as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not set
dim (int): input dimension to the module
dim_out (int): output dimension of the module, same as dim if not set
feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
stride (int): output stride of the module, avg pool used if stride == 2
num_heads (int): parallel attention heads.
dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool): add bias to q, k, and v projections
def __init__(
dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
qk_ratio=1.0, qkv_bias=False):
self.dim = dim
self.dim_out = dim_out or dim
self.dim_k = dim_head # query depth 'k'
dim_out = dim_out or dim
assert dim_out % num_heads == 0, ' should be divided by num_heads'
self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
self.num_heads = num_heads
assert self.dim_out % num_heads == 0, ' should be divided by num_heads'
self.dim_v = self.dim_out // num_heads # value depth 'v'
self.dim_v = dim_out // num_heads
self.qkv = nn.Conv2d(
num_heads * dim_head + dim_head + self.dim_v,
num_heads * self.dim_qk + self.dim_qk + self.dim_v,
kernel_size=1, bias=qkv_bias)
self.norm_q = nn.BatchNorm2d(num_heads * dim_head)
self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
self.norm_v = nn.BatchNorm2d(self.dim_v)
if r is not None:
# local lambda convolution for pos
self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
self.pos_emb = None
self.rel_pos_indices = None
@ -74,7 +92,7 @@ class LambdaLayer(nn.Module):
feat_size = to_2tuple(feat_size)
rel_size = [2 * s - 1 for s in feat_size]
self.conv_lambda = None
self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k))
self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
@ -82,9 +100,9 @@ class LambdaLayer(nn.Module):
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
if self.conv_lambda is not None:
trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
if self.pos_emb is not None:
trunc_normal_(self.pos_emb, std=.02)
@ -93,17 +111,17 @@ class LambdaLayer(nn.Module):
M = H * W
qkv = self.qkv(x)
q, k, v = torch.split(qkv, [
self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1)
q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K
self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K
v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M
k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
content_lam = k @ v # B, K, V
content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
if self.pos_emb is None:
position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
# FIXME relative pos embedding path not fully verified
pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)

@ -51,10 +51,10 @@ default_cfgs = {
interpolation='bicubic', first_conv='conv1.0'),
'resnet26t': _cfg(
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
'resnet50': _cfg(
interpolation='bicubic', crop_pct=0.95),
'resnet50d': _cfg(
interpolation='bicubic', first_conv='conv1.0'),

@ -105,13 +105,15 @@ default_cfgs = {
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),
'resnetv2_50': _cfg(
interpolation='bicubic', crop_pct=0.95),
'resnetv2_50d': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50t': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_101': _cfg(
interpolation='bicubic', crop_pct=0.95),
'resnetv2_101d': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_152': _cfg(
@ -470,7 +472,7 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
ResNetV2, variant, pretrained,
pretrained_custom_load='_bit' in variant,

@ -7,7 +7,7 @@ from .jit import set_jit_legacy
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict
from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed
from .summary import update_summary, get_outdir

@ -2,39 +2,38 @@
Hacked together by / Copyright 2020 Ross Wightman
from .model_ema import ModelEma
import torch
import fnmatch
_SUB_MODULE_ATTR = ('module', 'model')
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from .model_ema import ModelEma
def unwrap_model(model, recursive=True):
for attr in _SUB_MODULE_ATTR:
sub_module = getattr(model, attr, None)
if sub_module is not None:
return unwrap_model(sub_module) if recursive else sub_module
return model
def unwrap_model(model):
if isinstance(model, ModelEma):
return unwrap_model(model.ema)
return model.module if hasattr(model, 'module') else model
def get_state_dict(model, unwrap_fn=unwrap_model):
return unwrap_fn(model).state_dict()
def avg_sq_ch_mean(model, input, output):
"""calculate average channel square mean of output activations
return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()
def avg_sq_ch_mean(model, input, output):
"calculate average channel square mean of output activations"
return torch.mean(output.mean(axis=[0,2,3])**2).item()
def avg_ch_var(model, input, output):
"""calculate average channel variance of output activations"""
return torch.mean(output.var(axis=[0, 2, 3])).item()
def avg_ch_var(model, input, output):
"calculate average channel variance of output activations"
return torch.mean(output.var(axis=[0,2,3])).item()\
def avg_ch_var_residual(model, input, output):
"""calculate average channel variance of output activations"""
return torch.mean(output.var(axis=[0, 2, 3])).item()
def avg_ch_var_residual(model, input, output):
"calculate average channel variance of output activations"
return torch.mean(output.var(axis=[0,2,3])).item()
class ActivationStatsHook:
@ -63,16 +62,15 @@ class ActivationStatsHook:
raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
their lengths are different.")
self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
self.register_hook(hook_fn_loc, hook_fn)
def _create_hook(self, hook_fn):
def append_activation_stats(module, input, output):
out = hook_fn(module, input, output)
return append_activation_stats
def register_hook(self, hook_fn_loc, hook_fn):
for name, module in self.model.named_modules():
if not fnmatch.fnmatch(name, hook_fn_loc):
@ -80,9 +78,9 @@ class ActivationStatsHook:
def extract_spp_stats(model,
def extract_spp_stats(model,
input_shape=[8, 3, 224, 224]):
"""Extract average square channel mean and variance of activations during
forward pass to plot Signal Propagation Plots (SPP).
@ -90,8 +88,180 @@ def extract_spp_stats(model,
Example Usage:
x = torch.normal(0., 1., input_shape)
hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
_ = model(x)
return hook.stats
def freeze_batch_norm_2d(module):
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
module (torch.nn.Module): Any PyTorch module.
torch.nn.Module: Resulting module
Inspired by
res = module
if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine: = = = =
res.eps = module.eps
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself an instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
module (torch.nn.Module): Any PyTorch module.
torch.nn.Module: Resulting module
Inspired by
res = module
if isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine: = = = =
res.eps = module.eps
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
done in place.
root_module (nn.Module): Root module relative to which the `submodules` are referenced.
submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as
named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
means that the whole root module will be (un)frozen. Defaults to []
include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers.
Defaults to `True`.
mode (str): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`.
assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
# Raise assertion here because we can't convert it in place
raise AssertionError(
"You have provided a batch norm layer as the `root module`. Please use "
"`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.")
if isinstance(submodules, str):
submodules = [submodules]
named_modules = submodules
submodules = [root_module.get_submodule(m) for m in submodules]
if not(len(submodules)):
named_modules, submodules = list(zip(*root_module.named_children()))
for n, m in zip(named_modules, submodules):
# (Un)freeze parameters
for p in m.parameters():
p.requires_grad = False if mode == 'freeze' else True
if include_bn_running_stats:
# Helper to add submodule specified as a named_module
def _add_submodule(module, name, submodule):
split = name.rsplit('.', 1)
if len(split) > 1:
module.get_submodule(split[0]).add_module(split[1], submodule)
module.add_module(name, submodule)
# Freeze batch norm
if mode == 'freeze':
res = freeze_batch_norm_2d(m)
# It's possible that `m` is a type of BatchNorm in itself, in which case `freeze_batch_norm_2d` won't
# convert it in place, but will return the converted result. In this case `res` holds the converted
# result and we may try to re-assign the named module
if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
_add_submodule(root_module, n, res)
# Unfreeze batch norm
res = unfreeze_batch_norm_2d(m)
# Ditto. See note above in mode == 'freeze' branch
if isinstance(m, FrozenBatchNorm2d):
_add_submodule(root_module, n, res)
def freeze(root_module, submodules=[], include_bn_running_stats=True):
Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
root_module (nn.Module): Root module relative to which `submodules` are referenced.
submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
means that the whole root module will be frozen. Defaults to `[]`.
include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and
`SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning,
it's good practice to freeze batch norm stats. And note that these are different to the affine parameters
which are just normal PyTorch parameters. Defaults to `True`.
Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`.
>>> model = timm.create_model('resnet18')
>>> # Freeze up to and including layer2
>>> submodules = [n for n, _ in model.named_children()]
>>> print(submodules)
['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc']
>>> freeze(model, submodules[:submodules.index('layer2') + 1])
>>> # Check for yourself that it works as expected
>>> print(model.layer2[0].conv1.weight.requires_grad)
>>> print(model.layer3[0].conv1.weight.requires_grad)
>>> # Unfreeze
>>> unfreeze(model)
_freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze")
def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
root_module (nn.Module): Root module relative to which `submodules` are referenced.
submodules (list[str]): List of submodules for which the parameters will be unfrozen. They are to be provided
as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty
list means that the whole root module will be unfrozen. Defaults to `[]`.
include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers.
These will be converted to `BatchNorm2d` in place. Defaults to `True`.
See example in docstring for `freeze`.
_freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")

@ -1 +1 @@
__version__ = '0.4.13'
__version__ = '0.5.0'
