From 908563d060d3c7f2e46583e0e431ab5331f7e558 Mon Sep 17 00:00:00 2001 From: Shoufa Chen Date: Sun, 26 Sep 2021 12:32:22 +0800 Subject: [PATCH 01/33] fix `use_amp` Fix https://github.com/rwightman/pytorch-image-models/issues/881 --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 3943c7d0..84d8b2ea 100755 --- a/train.py +++ b/train.py @@ -397,7 +397,7 @@ def main(): # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn - if has_apex and use_amp != 'native': + if has_apex and use_amp == 'apex': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: @@ -451,7 +451,7 @@ def main(): # setup distributed training if args.distributed: - if has_apex and use_amp != 'native': + if has_apex and use_amp == 'apex': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") From 51eaf9360d28528fea924fa037236fa993bc81f8 Mon Sep 17 00:00:00 2001 From: leondgarse Date: Thu, 30 Sep 2021 18:30:48 +0800 Subject: [PATCH 02/33] Remove a duplicate layer creation in byobnet.py `self.conv2_kxk` is repeated in `byobnet.py`. Remove the duplicate code. --- timm/models/byobnet.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index edce355a..59105c44 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -797,9 +797,6 @@ class BottleneckBlock(nn.Module): self.shortcut = nn.Identity() self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) - self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) From 0cb8ea432ce1648ba28171080216d84544b62d1d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 21 Sep 2021 12:46:42 +0100 Subject: [PATCH 03/33] wip --- timm/models/layers/norm.py | 40 +++++++++++++++++++++++ timm/utils/model.py | 65 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index aace107b..fc500807 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torchvision class GroupNorm(nn.GroupNorm): @@ -22,3 +23,42 @@ class LayerNorm2d(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Inherits from torchvision while adding the `convert_frozen_batchnorm` from + https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py + """ + + @classmethod + def convert_frozen_batchnorm(cls, module): + """ + Converts all BatchNorm layers of provided module into FrozenBatchNorm. If `module` is a type of BatchNorm, it + converts it into FrozenBatchNorm. Otherwise, the module is walked recursively and BatchNorm type layers are + converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
It doesn't have to be a BatchNorm variant in itself. + + Returns: + torch.nn.Module: Resulting module + """ + res = module + if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = cls(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = cls.convert_frozen_batchnorm(child) + if new_child is not child: + res.add_module(name, new_child) + return res + diff --git a/timm/utils/model.py b/timm/utils/model.py index bd46e2f4..d0fb69ed 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -2,10 +2,20 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from .model_ema import ModelEma +from logging import root +from typing import Sequence +import re +import warnings + import torch import fnmatch +from torch.nn.modules import module + +from .model_ema import ModelEma +from timm.models.layers.norm import FrozenBatchNorm2d + + def unwrap_model(model): if isinstance(model, ModelEma): return unwrap_model(model.ema) @@ -89,4 +99,55 @@ def extract_spp_stats(model, hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) _ = model(x) return hook.stats - \ No newline at end of file + + +def freeze(modules, root_module=None, include_bn_running_stats=True, mode=True): + """ + Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is + done in place. + Args: + modules (nn.Module or list[nn.Module] or str or list[str]): List of modules for which the parameters will be + (un)frozen. If a string or strings are provided these will be interpreted according to the named modules + of the provided ``root_module``. + root_module (nn.Module, optional): Root module relative to which named modules (accessible via + ``root_module.named_modules()``) are referenced. Must be provided if the `modules` argument is specified + with a string or strings. Defaults to `None`. + include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm layers. + Defaults to `True`. + mode (bool): Whether to freeze (`True`) or unfreeze (`False`). Defaults to `True`. + + TODO before finalizing PR: Implement unfreezing of batch norm + """ + + if not isinstance(modules, Sequence): + modules = [modules] + + if isinstance(modules[0], str): + assert root_module is not None, \ + "When providing strings for the `modules` argument, a `root_module` must be provided" + module_names = modules + modules = [root_module.get_submodule(m) for m in module_names] + + for n, m in zip(module_names, modules): + for p in m.parameters(): + p.requires_grad = (not mode) + if include_bn_running_stats: + res = FrozenBatchNorm2d.convert_frozen_batchnorm(m) + # It's possible that `m` is a type of BatchNorm in itself, in which case + # `FrozenBatchNorm2d.convert_frozen_batchnorm` won't convert it in place, but will return the converted + # result. 
In this case `res` holds the converted result and we may try to re-assign the named module + if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + if module_names is not None and root_module is not None: + root_module.add_module(n, res) + else: + raise RuntimeError( + "Could not freeze batch norm statistics due to a technical limitation. Hint: Try calling " + "`freeze` with a list of module names while providing a `root_module` argument.") + + +def unfreeze(modules, root_module=None, include_bn_running_stats=True): + """ + Idiomatic convenience function to call `freeze` with `mode == False`. See docstring of `freeze` for further + information. + """ + freeze(modules, root_module=root_module, include_bn_running_stats=include_bn_running_stats, mode=False) From 65c3d78b96c8adce220675841de199516b3041ff Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 2 Oct 2021 15:54:14 +0100 Subject: [PATCH 04/33] Freeze unfreeze functionality finalized. Tests added --- tests/test_utils.py | 60 ++++++++++++ timm/models/layers/norm.py | 40 -------- timm/utils/model.py | 195 +++++++++++++++++++++++++++++-------- 3 files changed, 214 insertions(+), 81 deletions(-) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..3e11eacc --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,60 @@ +from torch.nn.modules.batchnorm import BatchNorm2d +from torchvision.ops.misc import FrozenBatchNorm2d + +import timm +from timm.utils.model import freeze, unfreeze + + +def test_freeze_unfreeze(): + model = timm.create_model('resnet18') + + # Freeze all + freeze(model) + # Check top level module + assert model.fc.weight.requires_grad == False + # Check submodule + assert model.layer1[0].conv1.weight.requires_grad == False + # Check BN + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + + # Unfreeze all + unfreeze(model) + # Check top level module + assert model.fc.weight.requires_grad == True + # Check submodule + assert model.layer1[0].conv1.weight.requires_grad == True + # Check BN + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + + # Freeze some + freeze(model, ['layer1', 'layer2.0']) + # Check frozen + assert model.layer1[0].conv1.weight.requires_grad == False + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + assert model.layer2[0].conv1.weight.requires_grad == False + # Check not frozen + assert model.layer3[0].conv1.weight.requires_grad == True + assert isinstance(model.layer3[0].bn1, BatchNorm2d) + assert model.layer2[1].conv1.weight.requires_grad == True + + # Unfreeze some + unfreeze(model, ['layer1', 'layer2.0']) + # Check not frozen + assert model.layer1[0].conv1.weight.requires_grad == True + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + assert model.layer2[0].conv1.weight.requires_grad == True + + # Freeze BN + # From root + freeze(model, ['layer1.0.bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + # From direct parent + freeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + + # Unfreeze BN + unfreeze(model, ['layer1.0.bn1']) + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + # From direct parent + unfreeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, BatchNorm2d) \ No newline at end of file diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index fc500807..aace107b 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -3,7 
+3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torchvision class GroupNorm(nn.GroupNorm): @@ -23,42 +22,3 @@ class LayerNorm2d(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - - -class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed. - - Inherits from torchvision while adding the `convert_frozen_batchnorm` from - https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py - """ - - @classmethod - def convert_frozen_batchnorm(cls, module): - """ - Converts all BatchNorm layers of provided module into FrozenBatchNorm. If `module` is a type of BatchNorm, it - converts it into FrozenBatchNorm. Otherwise, the module is walked recursively and BatchNorm type layers are - converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. It doesn't have to be a BatchNorm variant in itself. - - Returns: - torch.nn.Module: Resulting module - """ - res = module - if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): - res = cls(module.num_features) - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for name, child in module.named_children(): - new_child = cls.convert_frozen_batchnorm(child) - if new_child is not child: - res.add_module(name, new_child) - return res - diff --git a/timm/utils/model.py b/timm/utils/model.py index d0fb69ed..c2786401 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -4,16 +4,12 @@ Hacked together by / Copyright 2020 Ross Wightman """ from logging import root from typing import Sequence -import re -import warnings import torch import fnmatch - -from torch.nn.modules import module +from torchvision.ops.misc import FrozenBatchNorm2d from .model_ema import ModelEma -from timm.models.layers.norm import FrozenBatchNorm2d def unwrap_model(model): @@ -99,55 +95,172 @@ def extract_spp_stats(model, hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) _ = model(x) return hook.stats - -def freeze(modules, root_module=None, include_bn_running_stats=True, mode=True): + +def freeze_batch_norm_2d(module): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
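+
+    Example (a minimal sketch, mirroring the checks in tests/test_utils.py)::
+
+        >>> import torchvision
+        >>> model = timm.create_model('resnet18')
+        >>> model = freeze_batch_norm_2d(model)
+        >>> isinstance(model.bn1, torchvision.ops.misc.FrozenBatchNorm2d)
+        True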
+ + Returns: + torch.nn.Module: Resulting module + """ + res = module + if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = freeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def unfreeze_batch_norm_2d(module): + """ + Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance + of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked + recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + """ + res = module + if isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = unfreeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): """ Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. Args: - modules (nn.Module or list[nn.Module] or str or list[str]): List of modules for which the parameters will be - (un)frozen. If a string or strings are provided these will be interpreted according to the named modules - of the provided ``root_module``. - root_module (nn.Module, optional): Root module relative to which named modules (accessible via - ``root_module.named_modules()``) are referenced. Must be provided if the `modules` argument is specified - with a string or strings. Defaults to `None`. - include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm layers. + root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be (un)frozen. Defaults to [] + include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers. Defaults to `True`. - mode (bool): Whether to freeze (`True`) or unfreeze (`False`). Defaults to `True`. - - TODO before finalizing PR: Implement unfreezing of batch norm + mode (bool): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`. 
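+
+    A minimal internal sketch (equivalent to the public `freeze`/`unfreeze` wrappers defined below)::
+
+        >>> model = timm.create_model('resnet18')
+        >>> _freeze_unfreeze(model, ['layer1'], mode='freeze')    # same as freeze(model, ['layer1'])
+        >>> _freeze_unfreeze(model, ['layer1'], mode='unfreeze')  # same as unfreeze(model, ['layer1'])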
""" - - if not isinstance(modules, Sequence): - modules = [modules] + assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"' + + if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + # Raise assertion here because we can't convert it in place + raise AssertionError( + "You have provided a batch norm layer as the `root module`. Please use " + "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.") - if isinstance(modules[0], str): - assert root_module is not None, \ - "When providing strings for the `modules` argument, a `root_module` must be provided" - module_names = modules - modules = [root_module.get_submodule(m) for m in module_names] + if isinstance(submodules, str): + submodules = [submodules] - for n, m in zip(module_names, modules): + named_modules = submodules + submodules = [root_module.get_submodule(m) for m in submodules] + + if not(len(submodules)): + named_modules, submodules = list(zip(*root_module.named_children())) + + for n, m in zip(named_modules, submodules): + # (Un)freeze parameters for p in m.parameters(): - p.requires_grad = (not mode) + p.requires_grad = (False if mode == 'freeze' else True) if include_bn_running_stats: - res = FrozenBatchNorm2d.convert_frozen_batchnorm(m) - # It's possible that `m` is a type of BatchNorm in itself, in which case - # `FrozenBatchNorm2d.convert_frozen_batchnorm` won't convert it in place, but will return the converted - # result. In this case `res` holds the converted result and we may try to re-assign the named module - if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): - if module_names is not None and root_module is not None: - root_module.add_module(n, res) + # Helper to add submodule specified as a named_module + def _add_submodule(module, name, submodule): + split = name.rsplit('.', 1) + if len(split) > 1: + module.get_submodule(split[0]).add_module(split[1], submodule) else: - raise RuntimeError( - "Could not freeze batch norm statistics due to a technical limitation. Hint: Try calling " - "`freeze` with a list of module names while providing a `root_module` argument.") + module.add_module(name, submodule) + # Freeze batch norm + if mode == 'freeze': + res = freeze_batch_norm_2d(m) + # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't + # convert it in place, but will return the converted result. In this case `res` holds the converted + # result and we may try to re-assign the named module + if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + _add_submodule(root_module, n, res) + # Unfreeze batch norm + else: + res = unfreeze_batch_norm_2d(m) + # Ditto. See note above in mode == 'freeze' branch + if isinstance(m, FrozenBatchNorm2d): + _add_submodule(root_module, n, res) + + +def freeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be frozen. Defaults to `[]`. 
+ include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and + `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning, + it's good practice to freeze batch norm stats. And note that these are different to the affine parameters + which are just normal PyTorch parameters. Defaults to `True`. + + Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`. + + Examples:: + + >>> model = timm.create_model('resnet18') + >>> # Freeze up to and including layer2 + >>> submodules = [n for n, _ in model.named_children()] + >>> print(submodules) + ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc'] + >>> freeze(model, submodules[:submodules.index('layer2') + 1]) + >>> # Check for yourself that it works as expected + >>> print(model.layer2[0].conv1.weight.requires_grad) + False + >>> print(model.layer3[0].conv1.weight.requires_grad) + True + >>> # Unfreeze + >>> unfreeze(model) + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze") -def unfreeze(modules, root_module=None, include_bn_running_stats=True): +def unfreeze(root_module, submodules=[], include_bn_running_stats=True): """ - Idiomatic convenience function to call `freeze` with `mode == False`. See docstring of `freeze` for further - information. + Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided + as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty + list means that the whole root module will be unfrozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers. + These will be converted to `BatchNorm2d` in place. Defaults to `True`. + + See example in docstring for `freeze`. 
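+
+    A minimal sketch, mirroring the assertions in tests/test_utils.py::
+
+        >>> model = timm.create_model('resnet18')
+        >>> freeze(model, ['layer1'])
+        >>> model.layer1[0].conv1.weight.requires_grad
+        False
+        >>> unfreeze(model, ['layer1'])
+        >>> model.layer1[0].conv1.weight.requires_grad
+        True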
""" - freeze(modules, root_module=root_module, include_bn_running_stats=include_bn_running_stats, mode=False) + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze") + \ No newline at end of file From 6d2acec1bb1fa40fa8f65771e2f18c805920d4e3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 2 Oct 2021 16:10:11 +0100 Subject: [PATCH 05/33] Fix ordering of tests --- tests/test_utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 3e11eacc..b0f890d2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -44,17 +44,14 @@ def test_freeze_unfreeze(): assert isinstance(model.layer1[0].bn1, BatchNorm2d) assert model.layer2[0].conv1.weight.requires_grad == True - # Freeze BN + # Freeze/unfreeze BN # From root freeze(model, ['layer1.0.bn1']) assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) - # From direct parent - freeze(model.layer1[0], ['bn1']) - assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) - - # Unfreeze BN unfreeze(model, ['layer1.0.bn1']) assert isinstance(model.layer1[0].bn1, BatchNorm2d) # From direct parent + freeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) unfreeze(model.layer1[0], ['bn1']) assert isinstance(model.layer1[0].bn1, BatchNorm2d) \ No newline at end of file From 007bc3932375a71e7c1e4aa0b7f0c79f3bb79f56 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 2 Oct 2021 15:51:42 -0700 Subject: [PATCH 06/33] Some halo and bottleneck attn code cleanup, add halonet50ts weights, use optimal crop ratios --- timm/models/byoanet.py | 9 ++-- timm/models/layers/bottleneck_attn.py | 12 ++--- timm/models/layers/halo_attn.py | 68 +++++++++++++++------------ 3 files changed, 48 insertions(+), 41 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 056813ef..f58b724c 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -3,7 +3,7 @@ A flexible network w/ dataclass based config for stacking NN blocks including self-attention (or similar) layers. 
-Currently used to implement experimential variants of: +Currently used to implement experimental variants of: * Bottleneck Transformers * Lambda ResNets * HaloNets @@ -46,15 +46,16 @@ default_cfgs = { 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth', - input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'sehalonet33ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'halonet50ts': _cfg( - url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth', + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'eca_halonext26ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth', - input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'lambda_resnet26t': _cfg( url='', diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index bf6af675..61859f9c 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -118,12 +118,12 @@ class BottleneckAttn(nn.Module): x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) q, k, v = torch.split(x, self.num_heads, dim=1) - attn_logits = (q @ k.transpose(-1, -2)) * self.scale - attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W + attn = (q @ k.transpose(-1, -2)) * self.scale + attn = attn + self.pos_embed(q) # B, num_heads, H * W, H * W + attn = attn.softmax(dim=-1) - attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W - attn_out = self.pool(attn_out) - return attn_out + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + out = self.pool(out) + return out diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index d298fc0b..034c66a8 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -106,22 +106,23 @@ class HaloAttn(nn.Module): assert dim_out % num_heads == 0 self.stride = stride self.num_heads = num_heads - self.dim_head = dim_head or dim // num_heads - self.dim_qk = num_heads * self.dim_head - self.dim_v = dim_out + self.dim_head_qk = dim_head or dim_out // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v self.block_size = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size - self.scale = self.dim_head ** -0.5 + self.scale = self.dim_head_qk ** -0.5 # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # data in unfolded block form. 
I haven't wrapped my head around how that'd look. - self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias) - self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias) + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) self.pos_embed = PosEmbedRel( - block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale) + block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) self.reset_parameters() @@ -143,37 +144,42 @@ class HaloAttn(nn.Module): q = self.q(x) # unfold - q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) + q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) # B, num_heads * dim_head * block_size ** 2, num_blocks - q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) # B * num_heads, num_blocks, block_size ** 2, dim_head kv = self.kv(x) # generate overlapping windows for kv kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( - B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1) - # NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity - # if self.stride_tricks: - # kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() - # kv = kv.as_strided(( - # B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), - # stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) - # else: - # kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) - # kv = kv.reshape( - # B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) - k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) - # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads - - attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? 
- attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 - - attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks + B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) + k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1) + # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v + attn = (q @ k.transpose(-1, -2)) * self.scale + attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks # fold - attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) - attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride) + out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) + out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride) # B, dim_out, H // stride, W // stride - return attn_out + return out + + +""" Two alternatives for overlapping windows. + +`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() + + if self.stride_tricks: + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + kv = kv.as_strided(( + B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + else: + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + kv = kv.reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) +""" From b2094f4ee845d89aca8de65ae9b6ae09829a8b8e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:31:22 -0700 Subject: [PATCH 07/33] support bits checkpoints in avg/load --- avg_checkpoints.py | 4 ++++ timm/models/helpers.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/avg_checkpoints.py b/avg_checkpoints.py index 1f7604b0..ea8bbe84 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path): metric = None if 'metric' in checkpoint: metric = checkpoint['metric'] + elif 'metrics' in checkpoint and 'metric_name' in checkpoint: + metrics = checkpoint['metrics'] + print(metrics) + metric = metrics[checkpoint['metric_name']] return metric diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 662a7a48..bd97cf20 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__) def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = 'state_dict' + state_dict_key = '' if isinstance(checkpoint, dict): - if use_ema and 'state_dict_ema' in checkpoint: + if use_ema and checkpoint.get('state_dict_ema', None) is not None: state_dict_key = 'state_dict_ema' - if state_dict_key and state_dict_key in checkpoint: + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 
'model' + if state_dict_key: + state_dict = checkpoint[state_dict_key] new_state_dict = OrderedDict() - for k, v in checkpoint[state_dict_key].items(): + for k, v in state_dict.items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v From 64495505b7bbf5438672d53804efaaa634bad710 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:31:39 -0700 Subject: [PATCH 08/33] Add updated lambda resnet26 and botnet26 checkpoints with fixes applied --- timm/models/byoanet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index f58b724c..3c43378a 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -34,8 +34,8 @@ def _cfg(url='', **kwargs): default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg( - url='', - fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), 'botnet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), @@ -58,13 +58,13 @@ default_cfgs = { input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'lambda_resnet26t': _cfg( - url='', - min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth', + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), 'lambda_resnet50ts': _cfg( url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet26rpt_256': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } From cc9bedf373209664854dc400cbe5801e3fc1e6e9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:32:02 -0700 Subject: [PATCH 09/33] Add initial ResNet Strikes Back weights for ResNet50 and ResNetV2-50 models --- timm/models/resnet.py | 4 ++-- timm/models/resnetv2.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index dad42f38..1f0716c5 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -53,11 +53,11 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), 'resnet50': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-00ca2c6a.pth', interpolation='bicubic'), 'resnet50d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', - interpolation='bicubic', first_conv='conv1.0'), + interpolation='bicubic', first_conv='conv1.0', crop_pct=0.95), 'resnet50t': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 2b5121a2..fe7fc466 100644 --- a/timm/models/resnetv2.py +++ 
b/timm/models/resnetv2.py @@ -105,7 +105,8 @@ default_cfgs = { input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), 'resnetv2_50': _cfg( - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1_h-000cdf49.pth', + interpolation='bicubic', crop_pct=0.95), 'resnetv2_50d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_50t': _cfg( From da0d39bedd873c17d6cd2af50d78cbed564019c7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:33:16 -0700 Subject: [PATCH 10/33] Update default crop_pct for byoanet --- timm/models/byoanet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 3c43378a..61f94490 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -23,7 +23,7 @@ __all__ = [] def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'crop_pct': 0.95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', 'fixed_input_size': False, 'min_input_size': (3, 224, 224), @@ -35,7 +35,7 @@ default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth', - fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), @@ -59,7 +59,7 @@ default_cfgs = { 'lambda_resnet26t': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth', - min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95), + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet50ts': _cfg( url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), From 93901e992f7bcb6bdb46729a307f67e39dd9b5fd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:34:57 -0700 Subject: [PATCH 11/33] Version bump to 0.5.0 for pending release post RSB and ATTN updates --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 779b9fc3..2b8877c5 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.13' +__version__ = '0.5.0' From d123042605c238c136ff94613c0490448c22dc62 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 21:38:47 -0700 Subject: [PATCH 12/33] Update README.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index fda37ca0..ba58c754 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,11 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### Oct 3, 2021 +* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. +* Attention model experiments are in as well (across byobnet.py/byoanet.py), along with weights. Details forthcoming. +* A lot more to add here... + ### Aug 18, 2021 * Optimizer bonanza! * Add LAMB and LARS optimizers, incl trust ratio clipping options. 
Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits)) From ae1ff5792fec2e5a0119ab0316a4222b538dfd51 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 4 Oct 2021 16:46:00 -0700 Subject: [PATCH 13/33] Clean a1/a2/3 rsb _0 checkpoints properly, fix v2 loading. --- clean_checkpoint.py | 14 +++----------- timm/models/resnet.py | 2 +- timm/models/resnetv2.py | 2 +- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/clean_checkpoint.py b/clean_checkpoint.py index a8edcc91..34e8604a 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -13,6 +13,7 @@ import os import hashlib import shutil from collections import OrderedDict +from timm.models.helpers import load_state_dict parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', @@ -37,17 +38,8 @@ def main(): # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save if args.checkpoint and os.path.isfile(args.checkpoint): print("=> Loading checkpoint '{}'".format(args.checkpoint)) - checkpoint = torch.load(args.checkpoint, map_location='cpu') - - new_state_dict = OrderedDict() - if isinstance(checkpoint, dict): - state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict' - if state_dict_key in checkpoint: - state_dict = checkpoint[state_dict_key] - else: - state_dict = checkpoint - else: - assert False + state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema) + new_state_dict = {} for k, v in state_dict.items(): if args.clean_aux_bn and 'aux_bn' in k: # If all aux_bn keys are removed, the SplitBN layers will end up as normal and diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 1f0716c5..babb4c37 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -53,7 +53,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), 'resnet50': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-00ca2c6a.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', interpolation='bicubic'), 'resnet50d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index fe7fc466..b1344cb2 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -471,7 +471,7 @@ def _create_resnetv2(variant, pretrained=False, **kwargs): ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, - pretrained_custom_load=True, + pretrained_custom_load='_bit' in variant, **kwargs) From fbf59c04eeee5563c0f82b23c0643b511ecb2656 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 4 Oct 2021 22:31:08 -0700 Subject: [PATCH 14/33] Change crop ratio on correct resnet50 variant. 
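
crop_pct sets the eval-time resize/crop ratio resolved by timm's data config: the
short image side is resized to roughly floor(img_size / crop_pct) before a center
crop to img_size. A rough illustrative sketch of the math (not the exact library
code; values are examples only):

    import math
    img_size, crop_pct = 224, 0.95
    scale_size = int(math.floor(img_size / crop_pct))  # 235: resize short side to this
    # then center-crop the resized image down to 224x224 for evaluation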
--- timm/models/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index babb4c37..c1336458 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -54,10 +54,10 @@ default_cfgs = { interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', - interpolation='bicubic'), + interpolation='bicubic', crop_pct=0.95), 'resnet50d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', - interpolation='bicubic', first_conv='conv1.0', crop_pct=0.95), + interpolation='bicubic', first_conv='conv1.0'), 'resnet50t': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), From 431e60c83fb5db0991a5af55f9e9b635ecea9d8d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 6 Oct 2021 14:28:49 +0100 Subject: [PATCH 15/33] Add acknowledgements for freeze_batch_norm inspiration --- timm/utils/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/timm/utils/model.py b/timm/utils/model.py index c2786401..ffe66049 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -108,6 +108,8 @@ def freeze_batch_norm_2d(module): Returns: torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): @@ -139,6 +141,8 @@ def unfreeze_batch_norm_2d(module): Returns: torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module if isinstance(module, FrozenBatchNorm2d): From e0b3a3fab3db3abb5685ecfdf52cc76ba3027152 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Oct 2021 16:12:05 -0700 Subject: [PATCH 16/33] Make test-pooling flag for validate.py opt in --- validate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/validate.py b/validate.py index 9b2c0f7e..2e18841f 100755 --- a/validate.py +++ b/validate.py @@ -80,8 +80,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') -parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', - help='disable test time pool') +parser.add_argument('--test-pool', dest='test_pool', action='store_true', + help='enable test time pool') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--pin-mem', action='store_true', default=False, @@ -154,7 +154,7 @@ def validate(args): data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) test_time_pool = False - if not args.no_test_pool: + if args.test_pool: model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) if args.torchscript: From e2b8d44ff0220bf3ccb1b11c62f222975e80606f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Oct 2021 16:29:33 -0700 Subject: [PATCH 17/33] Halo, bottleneck attn, lambda layer additions and cleanup along w/ experimental model defs * align interfaces of halo, bottleneck attn and lambda layer * add 
qk_ratio to all of above, control q/k dim relative to output dim * add experimental haloregnetz, and trionet (lambda + halo + bottle) models --- timm/models/byoanet.py | 61 +++++++++++++++++++++++++++ timm/models/byobnet.py | 10 ++--- timm/models/layers/bottleneck_attn.py | 60 ++++++++++++++++++-------- timm/models/layers/halo_attn.py | 59 ++++++++++++++++++++------ timm/models/layers/lambda_layer.py | 54 ++++++++++++++++-------- 5 files changed, 191 insertions(+), 53 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 61f94490..8c816f6e 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -66,6 +66,13 @@ default_cfgs = { 'lambda_resnet26rpt_256': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + + 'haloregnetz_b': _cfg( + url='', + input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), + 'trionet50ts_256': _cfg( + url='', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -232,6 +239,46 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict(r=None) ), + + # experimental + haloregnetz_b=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3), + interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3), + ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3), + ), + stem_chs=32, + stem_pool='', + downsample='', + num_features=1536, + act_layer='silu', + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33) + ), + + # experimental + trionet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks( + types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25, + self_attn_layer='lambda', self_attn_kwargs=dict(r=13)), + interleave_blocks( + types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25, + self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)), + interleave_blocks( + types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25, + self_attn_layer='bottleneck', self_attn_kwargs=dict()), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + act_layer='silu', + ), ) @@ -327,3 +374,17 @@ def lambda_resnet26rpt_256(pretrained=False, **kwargs): """ kwargs.setdefault('img_size', 256) return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs) + + +@register_model +def haloregnetz_b(pretrained=False, **kwargs): + """ Halo + RegNetZ + """ + return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs) + + +@register_model +def trionet50ts_256(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet50-t backbone, silu act. 
Halo attention in final two stages + """ + return _create_byoanet('trionet50ts_256', 'trionet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 515f2073..4ac6ece3 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1096,18 +1096,16 @@ class SelfAttnBlock(nn.Module): self.self_attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_1x1(x) x = self.conv2_kxk(x) x = self.self_attn(x) x = self.post_attn(x) x = self.conv3_1x1(x) x = self.drop_path(x) - - x = self.act(x + shortcut) - return x - + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) _block_registry = dict( basic=BasicBlock, diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 61859f9c..15df62ae 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import to_2tuple +from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ @@ -66,10 +66,10 @@ class PosEmbedRel(nn.Module): self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) def forward(self, q): - B, num_heads, HW, _ = q.shape + B, HW, _ = q.shape # relative logits in width dimension. - q = q.reshape(B * num_heads, self.height, self.width, -1) + q = q.reshape(B, self.height, self.width, -1) rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) # relative logits in height dimension. @@ -77,35 +77,56 @@ class PosEmbedRel(nn.Module): rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) rel_logits = rel_logits_h + rel_logits_w - rel_logits = rel_logits.reshape(B, num_heads, HW, HW) + rel_logits = rel_logits.reshape(B, HW, HW) return rel_logits class BottleneckAttn(nn.Module): """ Bottleneck Attention Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + stride (int): output stride of the module, avg pool used if stride == 2 (default: 1). + num_heads (int): parallel attention heads (default: 4) + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections """ - def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, + qk_ratio=1.0, qkv_bias=False): super().__init__() assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' dim_out = dim_out or dim assert dim_out % num_heads == 0 self.num_heads = num_heads - self.dim_out = dim_out - self.dim_head = dim_out // num_heads - self.scale = self.dim_head ** -0.5 + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 - self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) + self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) # NOTE I'm only supporting relative pos embedding for now - self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() self.reset_parameters() def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in trunc_normal_(self.pos_embed.height_rel, std=self.scale) trunc_normal_(self.pos_embed.width_rel, std=self.scale) @@ -114,15 +135,20 @@ class BottleneckAttn(nn.Module): assert H == self.pos_embed.height assert W == self.pos_embed.width - x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W - x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) - q, k, v = torch.split(x, self.num_heads, dim=1) + x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W + + # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v + # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted. 
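+        # Illustrative sizing (assumed example values, not from a specific model config):
+        # dim_out=512, num_heads=4, qk_ratio=1.0 gives dim_head_qk = dim_head_v = 128,
+        # so the qkv conv above emits 512 (q) + 512 (k) + 512 (v) = 1536 channels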
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) - attn = (q @ k.transpose(-1, -2)) * self.scale - attn = attn + self.pos_embed(q) # B, num_heads, H * W, H * W + attn = (q @ k) * self.scale + attn = attn + self.pos_embed(q) # B * num_heads, H * W, H * W attn = attn.softmax(dim=-1) - out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W out = self.pool(out) return out diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 034c66a8..05fb1f6a 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -22,6 +22,7 @@ import torch from torch import nn import torch.nn.functional as F +from .helpers import make_divisible from .weight_init import trunc_normal_ @@ -98,31 +99,62 @@ class HaloAttn(nn.Module): Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` - https://arxiv.org/abs/2103.12731 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda) + stride: output stride of the module, query downscaled if > 1 (default: 1). + num_heads: parallel attention heads (default: 8). + dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + block_size (int): size of blocks. (default: 8) + halo_size (int): size of halo overlap. (default: 3) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool) : add bias to q, k, and v projections + avg_down (bool): use average pool downsample instead of strided query blocks + """ def __init__( - self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False): + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, + qk_ratio=1.0, qkv_bias=False, avg_down=False): super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0 - self.stride = stride + assert stride in (1, 2) self.num_heads = num_heads - self.dim_head_qk = dim_head or dim_out // num_heads + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads self.dim_head_v = dim_out // self.num_heads self.dim_out_qk = num_heads * self.dim_head_qk self.dim_out_v = num_heads * self.dim_head_v - self.block_size = block_size + self.scale = self.dim_head_qk ** -0.5 + self.block_size = self.block_size_ds = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size - self.scale = self.dim_head_qk ** -0.5 + self.block_stride = 1 + use_avg_pool = False + if stride > 1: + use_avg_pool = avg_down or block_size % stride != 0 + self.block_stride = 1 if use_avg_pool else stride + self.block_size_ds = self.block_size // self.block_stride # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # data in unfolded block form. I haven't wrapped my head around how that'd look. - self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias) + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias) self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) self.pos_embed = PosEmbedRel( - block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity() self.reset_parameters() @@ -140,11 +172,12 @@ class HaloAttn(nn.Module): num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks - bs_stride = self.block_size // self.stride q = self.q(x) # unfold - q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) + q = q.reshape( + -1, self.dim_head_qk, + num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4) # B, num_heads * dim_head * block_size ** 2, num_blocks q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) # B * num_heads, num_blocks, block_size ** 2, dim_head @@ -163,9 +196,11 @@ class HaloAttn(nn.Module): out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks # fold - out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) - out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride) - # B, dim_out, H // stride, W // stride + out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks) + out = out.permute(0, 3, 1, 4, 2).contiguous().view( + B, self.dim_out_v, H // self.block_stride, W // self.block_stride) + # B, dim_out, H // block_stride, W // block_stride + out = self.pool(out) return out diff --git 
a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index eeb77e45..e50b43c8 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -24,7 +24,7 @@ import torch from torch import nn import torch.nn.functional as F -from .helpers import to_2tuple +from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ @@ -44,28 +44,46 @@ class LambdaLayer(nn.Module): - https://arxiv.org/abs/2102.08602 NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. + + The internal dimensions of the lambda module are controlled via the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query (q) and key (k) dimension are determined by + * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None + * q = num_heads * dim_head, k = dim_head + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W + stride (int): output stride of the module, avg pool used if stride == 2 + num_heads (int): parallel attention heads. + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections """ def __init__( - self, - dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, + qk_ratio=1.0, qkv_bias=False): super().__init__() - self.dim = dim - self.dim_out = dim_out or dim - self.dim_k = dim_head # query depth 'k' + dim_out = dim_out or dim + assert dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads self.num_heads = num_heads - assert self.dim_out % num_heads == 0, ' should be divided by num_heads' - self.dim_v = self.dim_out // num_heads # value depth 'v' + self.dim_v = dim_out // num_heads self.qkv = nn.Conv2d( dim, - num_heads * dim_head + dim_head + self.dim_v, + num_heads * self.dim_qk + self.dim_qk + self.dim_v, kernel_size=1, bias=qkv_bias) - self.norm_q = nn.BatchNorm2d(num_heads * dim_head) + self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) self.norm_v = nn.BatchNorm2d(self.dim_v) if r is not None: # local lambda convolution for pos - self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) self.pos_emb = None self.rel_pos_indices = None else: @@ -74,7 +92,7 @@ class LambdaLayer(nn.Module): feat_size = to_2tuple(feat_size) rel_size = [2 * s - 1 for s in feat_size] self.conv_lambda = None - self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k)) + self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() @@ -82,9 +100,9 @@ class LambdaLayer(nn.Module): self.reset_parameters() def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in if self.conv_lambda is not None: - trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) if self.pos_emb is not None: trunc_normal_(self.pos_emb, std=.02) @@ -93,17 +111,17 @@ class LambdaLayer(nn.Module): M = H * W qkv = self.qkv(x) q, k, v = torch.split(qkv, [ - self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) - q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K + self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V - k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M + k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M content_lam = k @ v # B, K, V content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V if self.pos_emb is None: position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K - position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V else: # FIXME relative pos embedding path not fully verified pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) From 
e5da481073ac4beb634ee9b33e264baa3bee8688 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Oct 2021 17:00:27 -0700 Subject: [PATCH 18/33] Small post-merge tweak for freeze/unfreeze, add to __init__ for utils --- timm/utils/__init__.py | 2 +- timm/utils/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index d02e62d2..11de9c9c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -7,7 +7,7 @@ from .jit import set_jit_legacy from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg -from .model import unwrap_model, get_state_dict +from .model import unwrap_model, get_state_dict, freeze, unfreeze from .model_ema import ModelEma, ModelEmaV2 from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/model.py b/timm/utils/model.py index ffe66049..879ac3f8 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -194,7 +194,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, for n, m in zip(named_modules, submodules): # (Un)freeze parameters for p in m.parameters(): - p.requires_grad = (False if mode == 'freeze' else True) + p.requires_grad = False if mode == 'freeze' else True if include_bn_running_stats: # Helper to add submodule specified as a named_module def _add_submodule(module, name, submodule): From b544ad4d3fcd02057ab9f43b118290f2a089566f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Oct 2021 21:14:59 -0700 Subject: [PATCH 19/33] regnetz model default cfg tweaks --- timm/models/byoanet.py | 2 +- timm/models/byobnet.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 8c816f6e..8b629dc4 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -69,7 +69,7 @@ default_cfgs = { 'haloregnetz_b': _cfg( url='', - input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), + first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), 'trionet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 4ac6ece3..4363709f 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -137,12 +137,12 @@ default_cfgs = { # experimental models, likely to change ot be removed 'regnetz_b': _cfgr( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 224, 224), pool_size=(7, 7), first_conv='stem.conv'), + input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.95), 'regnetz_c': _cfgr( - url='', - imean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab_256-6bdb3c01.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.95), 'regnetz_d': _cfgr( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), From 38804c721b45f92f8139def38e2224a98c66eb0d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 8 Oct 2021 17:43:53 -0700 Subject: [PATCH 20/33] Checkpoint clean fn useable stand alone --- 
clean_checkpoint.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/clean_checkpoint.py b/clean_checkpoint.py
index 34e8604a..3eea15e6 100755
--- a/clean_checkpoint.py
+++ b/clean_checkpoint.py
@@ -20,7 +20,7 @@
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
 help='path to latest checkpoint (default: none)')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
 help='output path')
-parser.add_argument('--use-ema', dest='use_ema', action='store_true',
- help='use ema version of weights if present')
+parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
+ help='do not use ema version of weights if present')
 parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
 help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
@@ -35,19 +35,23 @@ def main():
 print("Error: Output filename ({}) already exists.".format(args.output))
 exit(1)

+ clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
+
+
+def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
 # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
- if args.checkpoint and os.path.isfile(args.checkpoint):
- print("=> Loading checkpoint '{}'".format(args.checkpoint))
- state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
+ if checkpoint and os.path.isfile(checkpoint):
+ print("=> Loading checkpoint '{}'".format(checkpoint))
+ state_dict = load_state_dict(checkpoint, use_ema=use_ema)
 new_state_dict = {}
 for k, v in state_dict.items():
- if args.clean_aux_bn and 'aux_bn' in k:
+ if clean_aux_bn and 'aux_bn' in k:
 # If all aux_bn keys are removed, the SplitBN layers will end up as normal and
 # load with the unmodified model using BatchNorm2d.
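 # (Below, the 'module.' prefix added by torch.nn.DataParallel / DistributedDataParallel
 # wrappers is stripped, e.g. 'module.conv1.weight' -> 'conv1.weight'.)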
continue name = k[7:] if k.startswith('module') else k new_state_dict[name] = v - print("=> Loaded state_dict from '{}'".format(args.checkpoint)) + print("=> Loaded state_dict from '{}'".format(checkpoint)) try: torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False) @@ -57,17 +61,19 @@ def main(): with open(_TEMP_NAME, 'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest() - if args.output: - checkpoint_root, checkpoint_base = os.path.split(args.output) + if output: + checkpoint_root, checkpoint_base = os.path.split(output) checkpoint_base = os.path.splitext(checkpoint_base)[0] else: checkpoint_root = '' - checkpoint_base = os.path.splitext(args.checkpoint)[0] + checkpoint_base = os.path.splitext(checkpoint)[0] final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth' shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename)) print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) + return final_filename else: - print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) + print("Error: Checkpoint ({}) doesn't exist".format(checkpoint)) + return '' if __name__ == '__main__': From a85df349936ae21c2a33c54a1b4a6522f0b0c9d0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 8 Oct 2021 17:44:13 -0700 Subject: [PATCH 21/33] Update lambda_resnet26rpt weights to 78.9, add better halonet26t weights at 79.1 with tweak to attention dim --- timm/models/byoanet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 8b629dc4..54c7081d 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -45,8 +45,8 @@ default_cfgs = { 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth', - input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth', + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'sehalonet33ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), @@ -64,8 +64,8 @@ default_cfgs = { url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet26rpt_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth', - fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'haloregnetz_b': _cfg( url='', @@ -149,7 +149,7 @@ model_cfgs = dict( stem_type='tiered', stem_pool='maxpool', self_attn_layer='halo', - self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16) + self_attn_kwargs=dict(block_size=8, halo_size=2) ), sehalonet33ts=ByoModelCfg( blocks=( From 44d6d51668ce53ba92ea746cfe29b759f688920a Mon Sep 17 00:00:00 2001 From: ICLR Author Date: Sat, 9 Oct 2021 21:09:51 -0400 Subject: [PATCH 22/33] Add ConvMixer --- timm/models/__init__.py | 1 + 
timm/models/convmixer.py | 101 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 timm/models/convmixer.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 56a753b1..0982b6e1 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -4,6 +4,7 @@ from .byobnet import * from .cait import * from .coat import * from .convit import * +from .convmixer import * from .crossvit import * from .cspnet import * from .densenet import * diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py new file mode 100644 index 00000000..a2400782 --- /dev/null +++ b/timm/models/convmixer.py @@ -0,0 +1,101 @@ +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.registry import register_model +from .helpers import build_model_with_cfg + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + 'first_conv': 'stem.0', + **kwargs + } + + +default_cfgs = { + 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'), + 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'), + 'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar') +} + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +class ConvMixer(nn.Module): + def __init__(self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, activation=nn.GELU, **kwargs): + super().__init__() + self.num_classes = num_classes + self.num_features = dim + self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), + activation(), + nn.BatchNorm2d(dim) + ) + self.blocks = nn.Sequential( + *[nn.Sequential( + Residual(nn.Sequential( + nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), + activation(), + nn.BatchNorm2d(dim) + )), + nn.Conv2d(dim, dim, kernel_size=1), + activation(), + nn.BatchNorm2d(dim) + ) for i in range(depth)] + ) + self.pooling = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten() + ) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.pooling(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + return x + + +def _create_convmixer(variant, pretrained=False, **kwargs): + return build_model_with_cfg(ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + + +@register_model +def convmixer_1536_20(pretrained=False, **kwargs): + model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) + return _create_convmixer('convmixer_1536_20', pretrained, **model_args) + + +@register_model +def convmixer_768_32(pretrained=False, **kwargs): + model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, 
activation=nn.ReLU, **kwargs) + return _create_convmixer('convmixer_768_32', pretrained, **model_args) + + +@register_model +def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs): + model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs) + return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args) \ No newline at end of file From 6ed4cdccca23e14de502f1f5b7087eb976238679 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 10 Oct 2021 16:32:54 -0700 Subject: [PATCH 23/33] Update lambda_resnet26t weights with better set --- timm/models/byoanet.py | 4 ++-- timm/models/resnet.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 54c7081d..313af3e2 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -58,8 +58,8 @@ default_cfgs = { input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'lambda_resnet26t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth', - min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth', + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'lambda_resnet50ts': _cfg( url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), diff --git a/timm/models/resnet.py b/timm/models/resnet.py index c1336458..bca1de46 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -51,7 +51,7 @@ default_cfgs = { interpolation='bicubic', first_conv='conv1.0'), 'resnet26t': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', interpolation='bicubic', crop_pct=0.95), From cd34913278f8511ba53492ed186b4e08f890add6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 11 Oct 2021 22:43:41 -0700 Subject: [PATCH 24/33] Remove some outdated comments, botnet networks working great now. --- timm/models/byoanet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 313af3e2..d296d4ba 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -294,7 +294,6 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): @register_model def botnet26t_256(pretrained=False, **kwargs): """ Bottleneck Transformer w/ ResNet26-T backbone. - NOTE: this isn't performing well, may remove """ kwargs.setdefault('img_size', 256) return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs) @@ -303,7 +302,6 @@ def botnet26t_256(pretrained=False, **kwargs): @register_model def botnet50ts_256(pretrained=False, **kwargs): """ Bottleneck Transformer w/ ResNet50-T backbone, silu act. 
- NOTE: this isn't performing well, may remove
 """
 kwargs.setdefault('img_size', 256)
 return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
@@ -312,7 +310,6 @@ def eca_botnext26ts_256(pretrained=False, **kwargs):
 """ Bottleneck Transformer w/ ResNet26-T backbone, silu act.
- NOTE: this isn't performing well, may remove
 """
 kwargs.setdefault('img_size', 256)
 return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
@@ -385,6 +382,6 @@ def haloregnetz_b(pretrained=False, **kwargs):

@register_model
def trionet50ts_256(pretrained=False, **kwargs):
- """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
+ """ TrioNet
 """
 return _create_byoanet('trionet50ts_256', 'trionet50ts', pretrained=pretrained, **kwargs)

From 047a5ec05f267604c3d5f33cb575316aad94e94c Mon Sep 17 00:00:00 2001
From: masafumi
Date: Tue, 12 Oct 2021 23:51:46 +0900
Subject: [PATCH 25/33] Fix bug where Mixup does not work when device='cpu'

---
 timm/data/mixup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/timm/data/mixup.py b/timm/data/mixup.py
index 38477548..7e382c52 100644
--- a/timm/data/mixup.py
+++ b/timm/data/mixup.py
@@ -214,7 +214,7 @@ class Mixup:
 lam = self._mix_pair(x)
 else:
 lam = self._mix_batch(x)
- target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
 return x, target

From 02daf2ab943ce2c1646c4af65026114facf4eb22 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 12 Oct 2021 15:35:43 -0700
Subject: [PATCH 26/33] Add option to include relative pos embedding in the
 attention scaling as per references. See discussion #912

---
 timm/models/layers/bottleneck_attn.py | 15 +++++++++------
 timm/models/layers/halo_attn.py | 17 ++++++++++-------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py
index 15df62ae..f55fd989 100644
--- a/timm/models/layers/bottleneck_attn.py
+++ b/timm/models/layers/bottleneck_attn.py
@@ -61,9 +61,8 @@ class PosEmbedRel(nn.Module):
 super().__init__()
 self.height, self.width = to_2tuple(feat_size)
 self.dim_head = dim_head
- self.scale = scale
- self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
- self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
+ self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale)
+ self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)

 def forward(self, q):
 B, HW, _ = q.shape
@@ -101,10 +100,11 @@ class BottleneckAttn(nn.Module):
 dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
 qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set.
(default: 1.0) qkv_bias (bool): add bias to q, k, and v projections + scale_pos_embed (bool): scale the position embedding as well as Q @ K """ def __init__( self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, - qk_ratio=1.0, qkv_bias=False): + qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): super().__init__() assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' dim_out = dim_out or dim @@ -115,6 +115,7 @@ class BottleneckAttn(nn.Module): self.dim_out_qk = num_heads * self.dim_head_qk self.dim_out_v = num_heads * self.dim_head_v self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) @@ -144,8 +145,10 @@ class BottleneckAttn(nn.Module): k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) - attn = (q @ k) * self.scale - attn = attn + self.pos_embed(q) # B * num_heads, H * W, H * W + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) attn = attn.softmax(dim=-1) out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 05fb1f6a..846c12ff 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -74,9 +74,8 @@ class PosEmbedRel(nn.Module): super().__init__() self.block_size = block_size self.dim_head = dim_head - self.scale = scale - self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) - self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) def forward(self, q): B, BB, HW, _ = q.shape @@ -120,11 +119,11 @@ class HaloAttn(nn.Module): qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) qkv_bias (bool) : add bias to q, k, and v projections avg_down (bool): use average pool downsample instead of strided query blocks - + scale_pos_embed (bool): scale the position embedding as well as Q @ K """ def __init__( self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, - qk_ratio=1.0, qkv_bias=False, avg_down=False): + qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False): super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0 @@ -135,6 +134,7 @@ class HaloAttn(nn.Module): self.dim_out_qk = num_heads * self.dim_head_qk self.dim_out_v = num_heads * self.dim_head_v self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed self.block_size = self.block_size_ds = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size @@ -190,8 +190,11 @@ class HaloAttn(nn.Module): k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1) # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v - attn = (q @ k.transpose(-1, -2)) * self.scale - attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 + if self.scale_pos_embed: + attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale + else: + attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q) + # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 attn = attn.softmax(dim=-1) out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks From c02334d9fad88e391ca120f08fa54d42ba74003e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 19 Oct 2021 12:32:09 -0700 Subject: [PATCH 27/33] Add weights for regnetz_d and haloregnetz_c, update regnetz_c weights. 
Add commented PyTorch XLA code for halo attention --- timm/models/byoanet.py | 3 ++- timm/models/byobnet.py | 10 +++++----- timm/models/layers/halo_attn.py | 15 ++++++++++++--- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index d296d4ba..c7a5c53e 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -68,7 +68,8 @@ default_cfgs = { fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'haloregnetz_b': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), 'trionet50ts_256': _cfg( url='', diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 4363709f..93898209 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -139,13 +139,13 @@ default_cfgs = { 'regnetz_b': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.95), + input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94), 'regnetz_c': _cfgr( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab_256-6bdb3c01.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.95), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94), 'regnetz_d': _cfgr( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), } diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 846c12ff..4149e812 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -183,7 +183,9 @@ class HaloAttn(nn.Module): # B * num_heads, num_blocks, block_size ** 2, dim_head kv = self.kv(x) - # generate overlapping windows for kv + # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not + # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach. + # FIXME figure out how to switch impl between this and conv2d if XLA being used. kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) @@ -207,17 +209,24 @@ class HaloAttn(nn.Module): return out -""" Two alternatives for overlapping windows. +""" Three alternatives for overlapping windows. 
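(Illustrative shapes, using hypothetical sizes rather than a real timm config: for H = W = 32, block_size = 8, halo_size = 3, there are (32 // 8) ** 2 = 16 blocks, win_size = 8 + 2 * 3 = 14, and each overlapping kv window covers 14 * 14 = 196 positions of the padded feature map.)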
`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() - if self.stride_tricks: + if is_xla: + # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is + # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment. + WW = self.win_size ** 2 + pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size) + kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size) + elif self.stride_tricks: kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() kv = kv.as_strided(( B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) else: kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + kv = kv.reshape( B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) """ From b6caa356d2dcac2c02bc1c81b9d3ffbae3fc50ad Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 19 Oct 2021 12:44:28 -0700 Subject: [PATCH 28/33] Fixed eca_botnext26ts_256 weights added, 79.27 --- timm/models/byoanet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index c7a5c53e..dfcba46f 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -40,7 +40,7 @@ default_cfgs = { url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'eca_botnext26ts_256': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), @@ -122,7 +122,7 @@ model_cfgs = dict( act_layer='silu', attn_layer='eca', self_attn_layer='bottleneck', - self_attn_kwargs=dict() + self_attn_kwargs=dict(dim_head=16) ), halonet_h1=ByoModelCfg( From 0ba73e6bcb6cfb9ed2aabd5d90e659f491a61302 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 19 Oct 2021 14:38:56 -0700 Subject: [PATCH 29/33] Update README.md --- README.md | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ba58c754..13b0d587 100644 --- a/README.md +++ b/README.md @@ -19,14 +19,21 @@ In addition to the sponsors at the link above, I've received hardware and/or clo * Nvidia (https://www.nvidia.com/en-us/) * TFRC (https://www.tensorflow.org/tfrc) -I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of hardware, infrastructure, and electricty costs. +I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs. ## What's New -### Oct 3, 2021 -* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. -* Attention model experiments are in as well (across byobnet.py/byoanet.py), along with weights. Details forthcoming. 
-* A lot more to add here...
+### Oct 19, 2021
+* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
+* BCE loss and Repeated Augmentation support for RSB paper
+* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottleneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
+* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
+ * Halo (https://arxiv.org/abs/2103.12731)
+ * Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
+ * LambdaNetworks (https://arxiv.org/abs/2102.08602)
+* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture; details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
+* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
+* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)

### Aug 18, 2021
* Optimizer bonanza!

From 13a8bf79720e6657f4043f15d69d271872753bb2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 19 Oct 2021 15:14:26 -0700
Subject: [PATCH 30/33] Add train size override and deepspeed GMACs counter
 (if deepspeed installed) to benchmark.py

---
 benchmark.py | 31 +++++++++++++++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index 903bb817..98f2ef84 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -18,6 +18,11 @@ from collections import OrderedDict
 from contextlib import suppress
 from functools import partial

+try:
+ from deepspeed.profiling.flops_profiler import get_model_profile
+except ImportError as e:
+ get_model_profile = None
+
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
 from timm.data import resolve_data_config
@@ -67,6 +72,8 @@ parser.add_argument('--img-size', default=None, type=int, metavar='N',
 help='Input image dimension, uses model default if empty')
 parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N',
 help='Input all image dimensions (d h w, e.g.
--input-size 3 224 224), uses model default if empty') +parser.add_argument('--use-train-size', action='store_true', default=False, + help='Run inference at train size, not test-input-size if it exists.') parser.add_argument('--num-classes', type=int, default=None, help='Number classes in dataset') parser.add_argument('--gp', default=None, type=str, metavar='POOL', @@ -81,6 +88,7 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') + # train optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -139,10 +147,25 @@ def resolve_precision(precision: str): return use_amp, model_dtype, data_dtype +def profile(model, input_size=(3, 224, 224)): + batch_size = 1 + macs, params = get_model_profile( + model=model, + input_res=(batch_size,) + input_size, # input shape or input to the input_constructor + input_constructor=None, # if specified, a constructor taking input_res is used as input to the model + print_profile=False, # prints the model graph with the measured profile attached to each module + detailed=False, # print the detailed profile + warm_up=10, # the number of warm-ups before measuring the time of each module + as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) + output_file=None, # path to the output file. If None, the profiler prints to stdout. + ignore_modules=None) # the list of modules to ignore in the profiling + return macs + + class BenchmarkRunner: def __init__( self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', - num_warm_iter=10, num_bench_iter=50, **kwargs): + num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): self.model_name = model_name self.detail = detail self.device = device @@ -166,7 +189,7 @@ class BenchmarkRunner: if torchscript: self.model = torch.jit.script(self.model) - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=True) + data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256) @@ -234,6 +257,10 @@ class InferenceBenchmarkRunner(BenchmarkRunner): param_count=round(self.param_count / 1e6, 2), ) + if get_model_profile is not None: + macs = profile(self.model, self.input_size) + results['GMACs'] = round(macs / 1e9, 2) + _logger.info( f"Inference benchmark of {self.model_name} done. 
" f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step") From 66253790d42e41064be9e53421e8b91dccbc890f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 19 Oct 2021 16:06:38 -0700 Subject: [PATCH 31/33] Add `--bench profile` mode for benchmark.py to just run deepspeed detailed profile on model --- benchmark.py | 47 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/benchmark.py b/benchmark.py index 98f2ef84..61bae0d4 100755 --- a/benchmark.py +++ b/benchmark.py @@ -147,19 +147,19 @@ def resolve_precision(precision: str): return use_amp, model_dtype, data_dtype -def profile(model, input_size=(3, 224, 224)): +def profile(model, input_size=(3, 224, 224), detailed=False): batch_size = 1 macs, params = get_model_profile( model=model, input_res=(batch_size,) + input_size, # input shape or input to the input_constructor input_constructor=None, # if specified, a constructor taking input_res is used as input to the model - print_profile=False, # prints the model graph with the measured profile attached to each module - detailed=False, # print the detailed profile + print_profile=detailed, # prints the model graph with the measured profile attached to each module + detailed=detailed, # print the detailed profile warm_up=10, # the number of warm-ups before measuring the time of each module as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) output_file=None, # path to the output file. If None, the profiler prints to stdout. ignore_modules=None) # the list of modules to ignore in the profiling - return macs + return macs, params class BenchmarkRunner: @@ -258,8 +258,8 @@ class InferenceBenchmarkRunner(BenchmarkRunner): ) if get_model_profile is not None: - macs = profile(self.model, self.input_size) - results['GMACs'] = round(macs / 1e9, 2) + macs, _ = profile(self.model, self.input_size) + results['gmacs'] = round(macs / 1e9, 2) _logger.info( f"Inference benchmark of {self.model_name} done. " @@ -388,6 +388,32 @@ class TrainBenchmarkRunner(BenchmarkRunner): return results +class ProfileRunner(BenchmarkRunner): + + def __init__(self, model_name, device='cuda', **kwargs): + super().__init__(model_name=model_name, device=device, **kwargs) + self.model.eval() + + def run(self): + _logger.info( + f'Running profiler on {self.model_name} w/ ' + f'input size {self.input_size} and batch size 1.') + + macs, params = profile(self.model, self.input_size, detailed=True) + + results = dict( + gmacs=round(macs / 1e9, 2), + img_size=self.input_size[-1], + param_count=round(params / 1e6, 2), + ) + + _logger.info( + f"Profile of {self.model_name} done. 
" + f"{results['gmacs']:.2f} GMACs, {results['param_count']:.2f} M params.") + + return results + + def decay_batch_exp(batch_size, factor=0.5, divisor=16): out_batch_size = batch_size * factor if out_batch_size > divisor: @@ -436,6 +462,9 @@ def benchmark(args): elif args.bench == 'train': bench_fns = TrainBenchmarkRunner, prefixes = 'train', + elif args.bench == 'profile': + assert get_model_profile is not None, "deepspeed needs to be installed for profile" + bench_fns = ProfileRunner, model_results = OrderedDict(model=model) for prefix, bench_fn in zip(prefixes, bench_fns): @@ -483,7 +512,11 @@ def main(): results.append(r) except KeyboardInterrupt as e: pass - sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec' + sort_key = 'infer_samples_per_sec' + if 'train' in args.bench: + sort_key = 'train_samples_per_sec' + elif 'profile' in args.bench: + sort_key = 'infer_gmacs' results = sorted(results, key=lambda x: x[sort_key], reverse=True) if len(results): write_results(results_file, results) From f7325c7b712100f79a9ab4ae54118d259c11bacf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 20 Oct 2021 15:17:30 -0700 Subject: [PATCH 32/33] Support either deepspeed or fvcore for flop profiling --- benchmark.py | 81 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 22 deletions(-) diff --git a/benchmark.py b/benchmark.py index 61bae0d4..477a0391 100755 --- a/benchmark.py +++ b/benchmark.py @@ -18,11 +18,6 @@ from collections import OrderedDict from contextlib import suppress from functools import partial -try: - from deepspeed.profiling.flops_profiler import get_model_profile -except ImportError as e: - get_model_profile = None - from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 from timm.data import resolve_data_config @@ -43,6 +38,20 @@ try: except AttributeError: pass +try: + from deepspeed.profiling.flops_profiler import get_model_profile + has_deepspeed_profiling = True +except ImportError as e: + has_deepspeed_profiling = False + +try: + from fvcore.nn import FlopCountAnalysis, flop_count_str + has_fvcore_profiling = True +except ImportError as e: + FlopCountAnalysis = None + has_fvcore_profiling = False + + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -147,9 +156,8 @@ def resolve_precision(precision: str): return use_amp, model_dtype, data_dtype -def profile(model, input_size=(3, 224, 224), detailed=False): - batch_size = 1 - macs, params = get_model_profile( +def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): + macs, _ = get_model_profile( model=model, input_res=(batch_size,) + input_size, # input shape or input to the input_constructor input_constructor=None, # if specified, a constructor taking input_res is used as input to the model @@ -159,7 +167,16 @@ def profile(model, input_size=(3, 224, 224), detailed=False): as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) output_file=None, # path to the output file. If None, the profiler prints to stdout. 
ignore_modules=None) # the list of modules to ignore in the profiling - return macs, params + return macs + + +def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False): + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + fca = FlopCountAnalysis(model, torch.ones((batch_size,) + input_size, device=device, dtype=dtype)) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total() class BenchmarkRunner: @@ -257,8 +274,11 @@ class InferenceBenchmarkRunner(BenchmarkRunner): param_count=round(self.param_count / 1e6, 2), ) - if get_model_profile is not None: - macs, _ = profile(self.model, self.input_size) + if has_deepspeed_profiling: + macs = profile_deepspeed(self.model, self.input_size) + results['gmacs'] = round(macs / 1e9, 2) + elif has_fvcore_profiling: + macs = profile_fvcore(self.model, self.input_size) results['gmacs'] = round(macs / 1e9, 2) _logger.info( @@ -390,21 +410,33 @@ class TrainBenchmarkRunner(BenchmarkRunner): class ProfileRunner(BenchmarkRunner): - def __init__(self, model_name, device='cuda', **kwargs): + def __init__(self, model_name, device='cuda', profiler='', **kwargs): super().__init__(model_name=model_name, device=device, **kwargs) + if not profiler: + if has_deepspeed_profiling: + profiler = 'deepspeed' + elif has_fvcore_profiling: + profiler = 'fvcore' + assert profiler, "One of deepspeed or fvcore needs to be installed for profiling to work." + self.profiler = profiler self.model.eval() def run(self): _logger.info( f'Running profiler on {self.model_name} w/ ' - f'input size {self.input_size} and batch size 1.') + f'input size {self.input_size} and batch size {self.batch_size}.') - macs, params = profile(self.model, self.input_size, detailed=True) + macs = 0 + if self.profiler == 'deepspeed': + macs = profile_deepspeed(self.model, self.input_size, batch_size=self.batch_size, detailed=True) + elif self.profiler == 'fvcore': + macs = profile_fvcore(self.model, self.input_size, batch_size=self.batch_size, detailed=True) results = dict( gmacs=round(macs / 1e9, 2), + batch_size=self.batch_size, img_size=self.input_size[-1], - param_count=round(params / 1e6, 2), + param_count=round(self.param_count / 1e6, 2), ) _logger.info( @@ -462,9 +494,16 @@ def benchmark(args): elif args.bench == 'train': bench_fns = TrainBenchmarkRunner, prefixes = 'train', - elif args.bench == 'profile': - assert get_model_profile is not None, "deepspeed needs to be installed for profile" + elif args.bench.startswith('profile'): + # specific profiler used if included in bench mode string, otherwise default to deepspeed, fallback to fvcore + if 'deepspeed' in args.bench: + assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter" + bench_kwargs['profiler'] = 'deepspeed' + elif 'fvcore' in args.bench: + assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter" + bench_kwargs['profiler'] = 'fvcore' bench_fns = ProfileRunner, + batch_size = 1 model_results = OrderedDict(model=model) for prefix, bench_fn in zip(prefixes, bench_fns): @@ -520,12 +559,10 @@ def main(): results = sorted(results, key=lambda x: x[sort_key], reverse=True) if len(results): write_results(results_file, results) - - import json - json_str = json.dumps(results, indent=4) - print(json_str) else: - benchmark(args) + results = benchmark(args) + json_str = json.dumps(results, indent=4) + print(json_str) def write_results(results_file, results): From 25e7c8c5e548f2063ffe8d83659dc4eea1d249cd Mon Sep 17 
00:00:00 2001 From: Ross Wightman Date: Wed, 20 Oct 2021 22:14:12 -0700 Subject: [PATCH 33/33] Update broken resnetv2_50 weight url, add resnetv1_101 a1h recipe weights for 224x224 train --- timm/models/resnetv2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index b1344cb2..43940cc3 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -105,14 +105,15 @@ default_cfgs = { input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), 'resnetv2_50': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1_h-000cdf49.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth', interpolation='bicubic', crop_pct=0.95), 'resnetv2_50d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_50t': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_101': _cfg( - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_101_a1h-5d01f016.pth', + interpolation='bicubic', crop_pct=0.95), 'resnetv2_101d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_152': _cfg(