diff --git a/README.md b/README.md index fda37ca0..13b0d587 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,22 @@ In addition to the sponsors at the link above, I've received hardware and/or clo * Nvidia (https://www.nvidia.com/en-us/) * TFRC (https://www.tensorflow.org/tfrc) -I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of hardware, infrastructure, and electricty costs. +I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs. ## What's New +### Oct 19, 2021 +* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights) +* BCE loss and Repeated Augmentation support for RSB paper +* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights) +* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl): + * Halo (https://arxiv.org/abs/2103.12731) + * Bottleneck Transformer (https://arxiv.org/abs/2101.11605) + * LambdaNetworks (https://arxiv.org/abs/2102.08602) +* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights) +* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added +* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare) + ### Aug 18, 2021 * Optimizer bonanza! * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits)) diff --git a/benchmark.py b/benchmark.py index 903bb817..477a0391 100755 --- a/benchmark.py +++ b/benchmark.py @@ -38,6 +38,20 @@ try: except AttributeError: pass +try: + from deepspeed.profiling.flops_profiler import get_model_profile + has_deepspeed_profiling = True +except ImportError as e: + has_deepspeed_profiling = False + +try: + from fvcore.nn import FlopCountAnalysis, flop_count_str + has_fvcore_profiling = True +except ImportError as e: + FlopCountAnalysis = None + has_fvcore_profiling = False + + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -67,6 +81,8 @@ parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +parser.add_argument('--use-train-size', action='store_true', default=False, + help='Run inference at train size, not test-input-size if it exists.') parser.add_argument('--num-classes', type=int, default=None, help='Number classes in dataset') parser.add_argument('--gp', default=None, type=str, metavar='POOL', @@ -81,6 +97,7 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') + # train optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -139,10 +156,33 @@ def resolve_precision(precision: str): return use_amp, model_dtype, data_dtype +def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): + macs, _ = get_model_profile( + model=model, + input_res=(batch_size,) + input_size, # input shape or input to the input_constructor + input_constructor=None, # if specified, a constructor taking input_res is used as input to the model + print_profile=detailed, # prints the model graph with the measured profile attached to each module + detailed=detailed, # print the detailed profile + warm_up=10, # the number of warm-ups before measuring the time of each module + as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) + output_file=None, # path to the output file. If None, the profiler prints to stdout. + ignore_modules=None) # the list of modules to ignore in the profiling + return macs + + +def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False): + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + fca = FlopCountAnalysis(model, torch.ones((batch_size,) + input_size, device=device, dtype=dtype)) + if detailed: + fcs = flop_count_str(fca) + print(fcs) + return fca.total() + + class BenchmarkRunner: def __init__( self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', - num_warm_iter=10, num_bench_iter=50, **kwargs): + num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): self.model_name = model_name self.detail = detail self.device = device @@ -166,7 +206,7 @@ class BenchmarkRunner: if torchscript: self.model = torch.jit.script(self.model) - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=True) + data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256) @@ -234,6 +274,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner): param_count=round(self.param_count / 1e6, 2), ) + if has_deepspeed_profiling: + macs = profile_deepspeed(self.model, self.input_size) + results['gmacs'] = round(macs / 1e9, 2) + elif has_fvcore_profiling: + macs = profile_fvcore(self.model, self.input_size) + results['gmacs'] = round(macs / 1e9, 2) + _logger.info( f"Inference benchmark of {self.model_name} done. " f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step") @@ -361,6 +408,44 @@ class TrainBenchmarkRunner(BenchmarkRunner): return results +class ProfileRunner(BenchmarkRunner): + + def __init__(self, model_name, device='cuda', profiler='', **kwargs): + super().__init__(model_name=model_name, device=device, **kwargs) + if not profiler: + if has_deepspeed_profiling: + profiler = 'deepspeed' + elif has_fvcore_profiling: + profiler = 'fvcore' + assert profiler, "One of deepspeed or fvcore needs to be installed for profiling to work." + self.profiler = profiler + self.model.eval() + + def run(self): + _logger.info( + f'Running profiler on {self.model_name} w/ ' + f'input size {self.input_size} and batch size {self.batch_size}.') + + macs = 0 + if self.profiler == 'deepspeed': + macs = profile_deepspeed(self.model, self.input_size, batch_size=self.batch_size, detailed=True) + elif self.profiler == 'fvcore': + macs = profile_fvcore(self.model, self.input_size, batch_size=self.batch_size, detailed=True) + + results = dict( + gmacs=round(macs / 1e9, 2), + batch_size=self.batch_size, + img_size=self.input_size[-1], + param_count=round(self.param_count / 1e6, 2), + ) + + _logger.info( + f"Profile of {self.model_name} done. " + f"{results['gmacs']:.2f} GMACs, {results['param_count']:.2f} M params.") + + return results + + def decay_batch_exp(batch_size, factor=0.5, divisor=16): out_batch_size = batch_size * factor if out_batch_size > divisor: @@ -409,6 +494,16 @@ def benchmark(args): elif args.bench == 'train': bench_fns = TrainBenchmarkRunner, prefixes = 'train', + elif args.bench.startswith('profile'): + # specific profiler used if included in bench mode string, otherwise default to deepspeed, fallback to fvcore + if 'deepspeed' in args.bench: + assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter" + bench_kwargs['profiler'] = 'deepspeed' + elif 'fvcore' in args.bench: + assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter" + bench_kwargs['profiler'] = 'fvcore' + bench_fns = ProfileRunner, + batch_size = 1 model_results = OrderedDict(model=model) for prefix, bench_fn in zip(prefixes, bench_fns): @@ -456,16 +551,18 @@ def main(): results.append(r) except KeyboardInterrupt as e: pass - sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec' + sort_key = 'infer_samples_per_sec' + if 'train' in args.bench: + sort_key = 'train_samples_per_sec' + elif 'profile' in args.bench: + sort_key = 'infer_gmacs' results = sorted(results, key=lambda x: x[sort_key], reverse=True) if len(results): write_results(results_file, results) - - import json - json_str = json.dumps(results, indent=4) - print(json_str) else: - benchmark(args) + results = benchmark(args) + json_str = json.dumps(results, indent=4) + print(json_str) def write_results(results_file, results): diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..b0f890d2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,57 @@ +from torch.nn.modules.batchnorm import BatchNorm2d +from torchvision.ops.misc import FrozenBatchNorm2d + +import timm +from timm.utils.model import freeze, unfreeze + + +def test_freeze_unfreeze(): + model = timm.create_model('resnet18') + + # Freeze all + freeze(model) + # Check top level module + assert model.fc.weight.requires_grad == False + # Check submodule + assert model.layer1[0].conv1.weight.requires_grad == False + # Check BN + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + + # Unfreeze all + unfreeze(model) + # Check top level module + assert model.fc.weight.requires_grad == True + # Check submodule + assert model.layer1[0].conv1.weight.requires_grad == True + # Check BN + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + + # Freeze some + freeze(model, ['layer1', 'layer2.0']) + # Check frozen + assert model.layer1[0].conv1.weight.requires_grad == False + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + assert model.layer2[0].conv1.weight.requires_grad == False + # Check not frozen + assert model.layer3[0].conv1.weight.requires_grad == True + assert isinstance(model.layer3[0].bn1, BatchNorm2d) + assert model.layer2[1].conv1.weight.requires_grad == True + + # Unfreeze some + unfreeze(model, ['layer1', 'layer2.0']) + # Check not frozen + assert model.layer1[0].conv1.weight.requires_grad == True + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + assert model.layer2[0].conv1.weight.requires_grad == True + + # Freeze/unfreeze BN + # From root + freeze(model, ['layer1.0.bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + unfreeze(model, ['layer1.0.bn1']) + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + # From direct parent + freeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + unfreeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, BatchNorm2d) \ No newline at end of file diff --git a/timm/data/mixup.py b/timm/data/mixup.py index b618bb7c..074b6941 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -216,7 +216,7 @@ class Mixup: lam = self._mix_pair(x) else: lam = self._mix_batch(x) - target = mixup_target(target, self.num_classes, lam, self.label_smoothing) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) return x, target diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 56a753b1..0982b6e1 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -4,6 +4,7 @@ from .byobnet import * from .cait import * from .coat import * from .convit import * +from .convmixer import * from .crossvit import * from .cspnet import * from .densenet import * diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index 056813ef..dfcba46f 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -3,7 +3,7 @@ A flexible network w/ dataclass based config for stacking NN blocks including self-attention (or similar) layers. -Currently used to implement experimential variants of: +Currently used to implement experimental variants of: * Bottleneck Transformers * Lambda ResNets * HaloNets @@ -23,7 +23,7 @@ __all__ = [] def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'crop_pct': 0.95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', 'fixed_input_size': False, 'min_input_size': (3, 224, 224), @@ -34,35 +34,44 @@ def _cfg(url='', **kwargs): default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'eca_botnext26ts_256': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'sehalonet33ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'halonet50ts': _cfg( - url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth', + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'eca_halonext26ts': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth', - input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94), 'lambda_resnet26t': _cfg( - url='', - min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth', + min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'lambda_resnet50ts': _cfg( url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet26rpt_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth', + fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), + + 'haloregnetz_b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), + 'trionet50ts_256': _cfg( url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -113,7 +122,7 @@ model_cfgs = dict( act_layer='silu', attn_layer='eca', self_attn_layer='bottleneck', - self_attn_kwargs=dict() + self_attn_kwargs=dict(dim_head=16) ), halonet_h1=ByoModelCfg( @@ -141,7 +150,7 @@ model_cfgs = dict( stem_type='tiered', stem_pool='maxpool', self_attn_layer='halo', - self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16) + self_attn_kwargs=dict(block_size=8, halo_size=2) ), sehalonet33ts=ByoModelCfg( blocks=( @@ -231,6 +240,46 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict(r=None) ), + + # experimental + haloregnetz_b=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3), + interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3), + ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3), + ), + stem_chs=32, + stem_pool='', + downsample='', + num_features=1536, + act_layer='silu', + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33) + ), + + # experimental + trionet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks( + types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25, + self_attn_layer='lambda', self_attn_kwargs=dict(r=13)), + interleave_blocks( + types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25, + self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)), + interleave_blocks( + types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25, + self_attn_layer='bottleneck', self_attn_kwargs=dict()), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + act_layer='silu', + ), ) @@ -246,7 +295,6 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): @register_model def botnet26t_256(pretrained=False, **kwargs): """ Bottleneck Transformer w/ ResNet26-T backbone. - NOTE: this isn't performing well, may remove """ kwargs.setdefault('img_size', 256) return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs) @@ -255,7 +303,6 @@ def botnet26t_256(pretrained=False, **kwargs): @register_model def botnet50ts_256(pretrained=False, **kwargs): """ Bottleneck Transformer w/ ResNet50-T backbone, silu act. - NOTE: this isn't performing well, may remove """ kwargs.setdefault('img_size', 256) return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) @@ -264,7 +311,6 @@ def botnet50ts_256(pretrained=False, **kwargs): @register_model def eca_botnext26ts_256(pretrained=False, **kwargs): """ Bottleneck Transformer w/ ResNet26-T backbone, silu act. - NOTE: this isn't performing well, may remove """ kwargs.setdefault('img_size', 256) return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs) @@ -326,3 +372,17 @@ def lambda_resnet26rpt_256(pretrained=False, **kwargs): """ kwargs.setdefault('img_size', 256) return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs) + + +@register_model +def haloregnetz_b(pretrained=False, **kwargs): + """ Halo + RegNetZ + """ + return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs) + + +@register_model +def trionet50ts_256(pretrained=False, **kwargs): + """ TrioNet + """ + return _create_byoanet('trionet50ts_256', 'trionet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 515f2073..93898209 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -137,15 +137,15 @@ default_cfgs = { # experimental models, likely to change ot be removed 'regnetz_b': _cfgr( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 224, 224), pool_size=(7, 7), first_conv='stem.conv'), + input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94), 'regnetz_c': _cfgr( - url='', - imean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94), 'regnetz_d': _cfgr( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), } @@ -1096,18 +1096,16 @@ class SelfAttnBlock(nn.Module): self.self_attn.reset_parameters() def forward(self, x): - shortcut = self.shortcut(x) - + shortcut = x x = self.conv1_1x1(x) x = self.conv2_kxk(x) x = self.self_attn(x) x = self.post_attn(x) x = self.conv3_1x1(x) x = self.drop_path(x) - - x = self.act(x + shortcut) - return x - + if self.shortcut is not None: + x = x + self.shortcut(shortcut) + return self.act(x) _block_registry = dict( basic=BasicBlock, diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py new file mode 100644 index 00000000..a2400782 --- /dev/null +++ b/timm/models/convmixer.py @@ -0,0 +1,101 @@ +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.registry import register_model +from .helpers import build_model_with_cfg + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + 'first_conv': 'stem.0', + **kwargs + } + + +default_cfgs = { + 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'), + 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'), + 'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar') +} + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +class ConvMixer(nn.Module): + def __init__(self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, activation=nn.GELU, **kwargs): + super().__init__() + self.num_classes = num_classes + self.num_features = dim + self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), + activation(), + nn.BatchNorm2d(dim) + ) + self.blocks = nn.Sequential( + *[nn.Sequential( + Residual(nn.Sequential( + nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), + activation(), + nn.BatchNorm2d(dim) + )), + nn.Conv2d(dim, dim, kernel_size=1), + activation(), + nn.BatchNorm2d(dim) + ) for i in range(depth)] + ) + self.pooling = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten() + ) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.pooling(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + return x + + +def _create_convmixer(variant, pretrained=False, **kwargs): + return build_model_with_cfg(ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + + +@register_model +def convmixer_1536_20(pretrained=False, **kwargs): + model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) + return _create_convmixer('convmixer_1536_20', pretrained, **model_args) + + +@register_model +def convmixer_768_32(pretrained=False, **kwargs): + model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, activation=nn.ReLU, **kwargs) + return _create_convmixer('convmixer_768_32', pretrained, **model_args) + + +@register_model +def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs): + model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs) + return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args) \ No newline at end of file diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index bf6af675..f55fd989 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import to_2tuple +from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ @@ -61,15 +61,14 @@ class PosEmbedRel(nn.Module): super().__init__() self.height, self.width = to_2tuple(feat_size) self.dim_head = dim_head - self.scale = scale - self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale) - self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale) def forward(self, q): - B, num_heads, HW, _ = q.shape + B, HW, _ = q.shape # relative logits in width dimension. - q = q.reshape(B * num_heads, self.height, self.width, -1) + q = q.reshape(B, self.height, self.width, -1) rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) # relative logits in height dimension. @@ -77,35 +76,58 @@ class PosEmbedRel(nn.Module): rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) rel_logits = rel_logits_h + rel_logits_w - rel_logits = rel_logits.reshape(B, num_heads, HW, HW) + rel_logits = rel_logits.reshape(B, HW, HW) return rel_logits class BottleneckAttn(nn.Module): """ Bottleneck Attention Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + stride (int): output stride of the module, avg pool used if stride == 2 (default: 1). + num_heads (int): parallel attention heads (default: 4) + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections + scale_pos_embed (bool): scale the position embedding as well as Q @ K """ - def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, + qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): super().__init__() assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' dim_out = dim_out or dim assert dim_out % num_heads == 0 self.num_heads = num_heads - self.dim_out = dim_out - self.dim_head = dim_out // num_heads - self.scale = self.dim_head ** -0.5 + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed - self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) + self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) # NOTE I'm only supporting relative pos embedding for now - self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() self.reset_parameters() def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in trunc_normal_(self.pos_embed.height_rel, std=self.scale) trunc_normal_(self.pos_embed.width_rel, std=self.scale) @@ -114,16 +136,23 @@ class BottleneckAttn(nn.Module): assert H == self.pos_embed.height assert W == self.pos_embed.width - x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W - x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) - q, k, v = torch.split(x, self.num_heads, dim=1) + x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W + + # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v + # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted. + q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) + k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k + v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) - attn_logits = (q @ k.transpose(-1, -2)) * self.scale - attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) + attn = attn.softmax(dim=-1) - attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W - attn_out = self.pool(attn_out) - return attn_out + out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W + out = self.pool(out) + return out diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index d298fc0b..4149e812 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -22,6 +22,7 @@ import torch from torch import nn import torch.nn.functional as F +from .helpers import make_divisible from .weight_init import trunc_normal_ @@ -73,9 +74,8 @@ class PosEmbedRel(nn.Module): super().__init__() self.block_size = block_size self.dim_head = dim_head - self.scale = scale - self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) - self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) def forward(self, q): B, BB, HW, _ = q.shape @@ -98,30 +98,63 @@ class HaloAttn(nn.Module): Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` - https://arxiv.org/abs/2103.12731 + + The internal dimensions of the attention module are controlled by the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query and key (qk) dimensions are determined by + * num_heads * dim_head if dim_head is not None + * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda) + stride: output stride of the module, query downscaled if > 1 (default: 1). + num_heads: parallel attention heads (default: 8). + dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + block_size (int): size of blocks. (default: 8) + halo_size (int): size of halo overlap. (default: 3) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool) : add bias to q, k, and v projections + avg_down (bool): use average pool downsample instead of strided query blocks + scale_pos_embed (bool): scale the position embedding as well as Q @ K """ def __init__( - self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False): + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, + qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False): super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0 - self.stride = stride + assert stride in (1, 2) self.num_heads = num_heads - self.dim_head = dim_head or dim // num_heads - self.dim_qk = num_heads * self.dim_head - self.dim_v = dim_out - self.block_size = block_size + self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads + self.dim_head_v = dim_out // self.num_heads + self.dim_out_qk = num_heads * self.dim_head_qk + self.dim_out_v = num_heads * self.dim_head_v + self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed + self.block_size = self.block_size_ds = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size - self.scale = self.dim_head ** -0.5 + self.block_stride = 1 + use_avg_pool = False + if stride > 1: + use_avg_pool = avg_down or block_size % stride != 0 + self.block_stride = 1 if use_avg_pool else stride + self.block_size_ds = self.block_size // self.block_stride # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # data in unfolded block form. I haven't wrapped my head around how that'd look. - self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias) - self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias) + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) self.pos_embed = PosEmbedRel( - block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale) + block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity() self.reset_parameters() @@ -139,41 +172,61 @@ class HaloAttn(nn.Module): num_h_blocks = H // self.block_size num_w_blocks = W // self.block_size num_blocks = num_h_blocks * num_w_blocks - bs_stride = self.block_size // self.stride q = self.q(x) # unfold - q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) + q = q.reshape( + -1, self.dim_head_qk, + num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4) # B, num_heads * dim_head * block_size ** 2, num_blocks - q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) + q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) # B * num_heads, num_blocks, block_size ** 2, dim_head kv = self.kv(x) - # generate overlapping windows for kv + # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not + # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach. + # FIXME figure out how to switch impl between this and conv2d if XLA being used. kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( - B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1) - # NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity - # if self.stride_tricks: - # kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() - # kv = kv.as_strided(( - # B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), - # stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) - # else: - # kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) - # kv = kv.reshape( - # B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) - k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) - # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads - - attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? - attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 - - attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks - + B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) + k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1) + # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v + + if self.scale_pos_embed: + attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale + else: + attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q) + # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 + attn = attn.softmax(dim=-1) + + out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks # fold - attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) - attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride) - # B, dim_out, H // stride, W // stride - return attn_out + out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks) + out = out.permute(0, 3, 1, 4, 2).contiguous().view( + B, self.dim_out_v, H // self.block_stride, W // self.block_stride) + # B, dim_out, H // block_stride, W // block_stride + out = self.pool(out) + return out + + +""" Three alternatives for overlapping windows. + +`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() + + if is_xla: + # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is + # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment. + WW = self.win_size ** 2 + pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size) + kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size) + elif self.stride_tricks: + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + kv = kv.as_strided(( + B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + else: + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + + kv = kv.reshape( + B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) +""" diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py index eeb77e45..e50b43c8 100644 --- a/timm/models/layers/lambda_layer.py +++ b/timm/models/layers/lambda_layer.py @@ -24,7 +24,7 @@ import torch from torch import nn import torch.nn.functional as F -from .helpers import to_2tuple +from .helpers import to_2tuple, make_divisible from .weight_init import trunc_normal_ @@ -44,28 +44,46 @@ class LambdaLayer(nn.Module): - https://arxiv.org/abs/2102.08602 NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. + + The internal dimensions of the lambda module are controlled via the interaction of several arguments. + * the output dimension of the module is specified by dim_out, which falls back to input dim if not set + * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim + * the query (q) and key (k) dimension are determined by + * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None + * q = num_heads * dim_head, k = dim_head + * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set + + Args: + dim (int): input dimension to the module + dim_out (int): output dimension of the module, same as dim if not set + feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W + stride (int): output stride of the module, avg pool used if stride == 2 + num_heads (int): parallel attention heads. + dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) + qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) + qkv_bias (bool): add bias to q, k, and v projections """ def __init__( - self, - dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): + self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, + qk_ratio=1.0, qkv_bias=False): super().__init__() - self.dim = dim - self.dim_out = dim_out or dim - self.dim_k = dim_head # query depth 'k' + dim_out = dim_out or dim + assert dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads self.num_heads = num_heads - assert self.dim_out % num_heads == 0, ' should be divided by num_heads' - self.dim_v = self.dim_out // num_heads # value depth 'v' + self.dim_v = dim_out // num_heads self.qkv = nn.Conv2d( dim, - num_heads * dim_head + dim_head + self.dim_v, + num_heads * self.dim_qk + self.dim_qk + self.dim_v, kernel_size=1, bias=qkv_bias) - self.norm_q = nn.BatchNorm2d(num_heads * dim_head) + self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) self.norm_v = nn.BatchNorm2d(self.dim_v) if r is not None: # local lambda convolution for pos - self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) self.pos_emb = None self.rel_pos_indices = None else: @@ -74,7 +92,7 @@ class LambdaLayer(nn.Module): feat_size = to_2tuple(feat_size) rel_size = [2 * s - 1 for s in feat_size] self.conv_lambda = None - self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k)) + self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() @@ -82,9 +100,9 @@ class LambdaLayer(nn.Module): self.reset_parameters() def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in if self.conv_lambda is not None: - trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) if self.pos_emb is not None: trunc_normal_(self.pos_emb, std=.02) @@ -93,17 +111,17 @@ class LambdaLayer(nn.Module): M = H * W qkv = self.qkv(x) q, k, v = torch.split(qkv, [ - self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) - q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K + self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V - k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M + k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M content_lam = k @ v # B, K, V content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V if self.pos_emb is None: position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K - position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V else: # FIXME relative pos embedding path not fully verified pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index dad42f38..bca1de46 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -51,10 +51,10 @@ default_cfgs = { interpolation='bicubic', first_conv='conv1.0'), 'resnet26t': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'resnet50': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', + interpolation='bicubic', crop_pct=0.95), 'resnet50d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', interpolation='bicubic', first_conv='conv1.0'), diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 2b5121a2..43940cc3 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -105,13 +105,15 @@ default_cfgs = { input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), 'resnetv2_50': _cfg( - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth', + interpolation='bicubic', crop_pct=0.95), 'resnetv2_50d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_50t': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_101': _cfg( - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_101_a1h-5d01f016.pth', + interpolation='bicubic', crop_pct=0.95), 'resnetv2_101d': _cfg( interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_152': _cfg( @@ -470,7 +472,7 @@ def _create_resnetv2(variant, pretrained=False, **kwargs): ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, - pretrained_custom_load=True, + pretrained_custom_load='_bit' in variant, **kwargs) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index d02e62d2..11de9c9c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -7,7 +7,7 @@ from .jit import set_jit_legacy from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg -from .model import unwrap_model, get_state_dict +from .model import unwrap_model, get_state_dict, freeze, unfreeze from .model_ema import ModelEma, ModelEmaV2 from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/model.py b/timm/utils/model.py index 66f7480e..b9f3e9d3 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -2,39 +2,38 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from .model_ema import ModelEma -import torch import fnmatch -_SUB_MODULE_ATTR = ('module', 'model') +import torch +from torchvision.ops.misc import FrozenBatchNorm2d + +from .model_ema import ModelEma -def unwrap_model(model, recursive=True): - for attr in _SUB_MODULE_ATTR: - sub_module = getattr(model, attr, None) - if sub_module is not None: - return unwrap_model(sub_module) if recursive else sub_module - return model +def unwrap_model(model): + if isinstance(model, ModelEma): + return unwrap_model(model.ema) + else: + return model.module if hasattr(model, 'module') else model def get_state_dict(model, unwrap_fn=unwrap_model): return unwrap_fn(model).state_dict() -def avg_sq_ch_mean(model, input, output): - """calculate average channel square mean of output activations - """ - return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item() +def avg_sq_ch_mean(model, input, output): + "calculate average channel square mean of output activations" + return torch.mean(output.mean(axis=[0,2,3])**2).item() -def avg_ch_var(model, input, output): - """calculate average channel variance of output activations""" - return torch.mean(output.var(axis=[0, 2, 3])).item() +def avg_ch_var(model, input, output): + "calculate average channel variance of output activations" + return torch.mean(output.var(axis=[0,2,3])).item()\ -def avg_ch_var_residual(model, input, output): - """calculate average channel variance of output activations""" - return torch.mean(output.var(axis=[0, 2, 3])).item() +def avg_ch_var_residual(model, input, output): + "calculate average channel variance of output activations" + return torch.mean(output.var(axis=[0,2,3])).item() class ActivationStatsHook: @@ -63,16 +62,15 @@ class ActivationStatsHook: raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \ their lengths are different.") self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns) - for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns): + for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns): self.register_hook(hook_fn_loc, hook_fn) def _create_hook(self, hook_fn): def append_activation_stats(module, input, output): out = hook_fn(module, input, output) self.stats[hook_fn.__name__].append(out) - return append_activation_stats - + def register_hook(self, hook_fn_loc, hook_fn): for name, module in self.model.named_modules(): if not fnmatch.fnmatch(name, hook_fn_loc): @@ -80,9 +78,9 @@ class ActivationStatsHook: module.register_forward_hook(self._create_hook(hook_fn)) -def extract_spp_stats(model, +def extract_spp_stats(model, hook_fn_locs, - hook_fns, + hook_fns, input_shape=[8, 3, 224, 224]): """Extract average square channel mean and variance of activations during forward pass to plot Signal Propogation Plots (SPP). @@ -90,8 +88,180 @@ def extract_spp_stats(model, Paper: https://arxiv.org/abs/2101.08692 Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 - """ + """ x = torch.normal(0., 1., input_shape) hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) _ = model(x) return hook.stats + + +def freeze_batch_norm_2d(module): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = freeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def unfreeze_batch_norm_2d(module): + """ + Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance + of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked + recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = unfreeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): + """ + Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is + done in place. + Args: + root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be (un)frozen. Defaults to [] + include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers. + Defaults to `True`. + mode (bool): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`. + """ + assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"' + + if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + # Raise assertion here because we can't convert it in place + raise AssertionError( + "You have provided a batch norm layer as the `root module`. Please use " + "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.") + + if isinstance(submodules, str): + submodules = [submodules] + + named_modules = submodules + submodules = [root_module.get_submodule(m) for m in submodules] + + if not(len(submodules)): + named_modules, submodules = list(zip(*root_module.named_children())) + + for n, m in zip(named_modules, submodules): + # (Un)freeze parameters + for p in m.parameters(): + p.requires_grad = False if mode == 'freeze' else True + if include_bn_running_stats: + # Helper to add submodule specified as a named_module + def _add_submodule(module, name, submodule): + split = name.rsplit('.', 1) + if len(split) > 1: + module.get_submodule(split[0]).add_module(split[1], submodule) + else: + module.add_module(name, submodule) + # Freeze batch norm + if mode == 'freeze': + res = freeze_batch_norm_2d(m) + # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't + # convert it in place, but will return the converted result. In this case `res` holds the converted + # result and we may try to re-assign the named module + if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + _add_submodule(root_module, n, res) + # Unfreeze batch norm + else: + res = unfreeze_batch_norm_2d(m) + # Ditto. See note above in mode == 'freeze' branch + if isinstance(m, FrozenBatchNorm2d): + _add_submodule(root_module, n, res) + + +def freeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be frozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and + `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning, + it's good practice to freeze batch norm stats. And note that these are different to the affine parameters + which are just normal PyTorch parameters. Defaults to `True`. + + Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`. + + Examples:: + + >>> model = timm.create_model('resnet18') + >>> # Freeze up to and including layer2 + >>> submodules = [n for n, _ in model.named_children()] + >>> print(submodules) + ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc'] + >>> freeze(model, submodules[:submodules.index('layer2') + 1]) + >>> # Check for yourself that it works as expected + >>> print(model.layer2[0].conv1.weight.requires_grad) + False + >>> print(model.layer3[0].conv1.weight.requires_grad) + True + >>> # Unfreeze + >>> unfreeze(model) + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze") + + +def unfreeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided + as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty + list means that the whole root module will be unfrozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers. + These will be converted to `BatchNorm2d` in place. Defaults to `True`. + + See example in docstring for `freeze`. + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze") diff --git a/timm/version.py b/timm/version.py index 779b9fc3..2b8877c5 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.13' +__version__ = '0.5.0'