From 95ec255f7f6086946a8c9e8b2db0405bf28e2be9 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 3 Feb 2023 21:21:23 -0800
Subject: [PATCH] Finish timm mode api for efficientformer_v2, add grad checkpointing support to both efficientformers

---
 timm/models/efficientformer.py    |  6 +++-
 timm/models/efficientformer_v2.py | 56 +++++++++++++++++++++++++++++--
 2 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py
index 665a37ce..21957d58 100644
--- a/timm/models/efficientformer.py
+++ b/timm/models/efficientformer.py
@@ -20,6 +20,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
@@ -335,7 +336,10 @@ class EfficientFormerStage(nn.Module):
 
     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x
 
 
diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py
index e51394b6..e2adccdb 100644
--- a/timm/models/efficientformer_v2.py
+++ b/timm/models/efficientformer_v2.py
@@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
 from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
 from ._pretrained import generate_default_cfgs
 from ._registry import register_model
 
@@ -498,7 +499,10 @@ class EfficientFormerV2Stage(nn.Module):
 
     def forward(self, x):
         x = self.downsample(x)
-        x = self.blocks(x)
+        if self.grad_checkpointing:
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
         return x
 
 
@@ -508,6 +512,7 @@ class EfficientFormerV2(nn.Module):
             depths,
             in_chans=3,
             img_size=224,
+            global_pool='avg',
             embed_dims=None,
             downsamples=None,
             mlp_ratios=4,
@@ -522,7 +527,9 @@ class EfficientFormerV2(nn.Module):
             distillation=True,
     ):
         super().__init__()
+        assert global_pool in ('avg', '')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.feature_info = []
         img_size = to_2tuple(img_size)
         norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
@@ -583,11 +590,49 @@ class EfficientFormerV2(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
 
-    def forward(self, x):
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^stem',  # stem and embed
+            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head, self.head_dist
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
+    def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
         x = self.norm(x)
-        x = x.mean(dim=(2, 3))
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(2, 3))
+        if pre_logits:
+            return x
         x, x_dist = self.head(x), self.head_dist(x)
         if self.distilled_training and self.training and not torch.jit.is_scripting():
             # only return separate classification predictions when training in distilled mode
@@ -596,6 +641,11 @@ class EfficientFormerV2(nn.Module):
             # during standard train/finetune, inference average the classifier predictions
             return (x + x_dist) / 2
 
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
 
 def _cfg(url='', **kwargs):
     return {
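
Usage note (not part of the patch): a minimal sketch of the API surface this change completes. The model name, input shape, and call sequence below are illustrative assumptions, not taken from the diff itself.

    import torch
    import timm

    # any efficientformer / efficientformer_v2 variant should apply; 's0' is assumed here
    model = timm.create_model('efficientformerv2_s0', num_classes=10)

    # per-stage gradient checkpointing wired up by this patch: trades extra compute
    # (each stage's blocks are recomputed in backward) for lower activation memory
    model.set_grad_checkpointing(True)

    # distilled training mode: head and head_dist return separate logits while training
    model.set_distilled_training(True)
    model.train()
    x = torch.randn(2, 3, 224, 224)
    logits, logits_dist = model(x)

    # standard timm model api now supported: features, then head, optionally pre-logits
    model.eval()
    feats = model.forward_features(x)                    # NCHW feature map after final norm
    pooled = model.forward_head(feats, pre_logits=True)  # globally avg-pooled features

Outside distilled training mode (or whenever the model is in eval mode), forward_head averages the two classifier outputs, so model(x) returns a single logits tensor.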