Finish timm mode api for efficientformer_v2, add grad checkpointing support to both efficientformers

pull/1655/head
Ross Wightman 2 years ago
parent 9d03c6f526
commit 95ec255f7f

@ -20,6 +20,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs from ._pretrained import generate_default_cfgs
from ._registry import register_model from ._registry import register_model
@ -335,7 +336,10 @@ class EfficientFormerStage(nn.Module):
def forward(self, x): def forward(self, x):
x = self.downsample(x) x = self.downsample(x)
x = self.blocks(x) if self.grad_checkpointing:
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x return x

@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs from ._pretrained import generate_default_cfgs
from ._registry import register_model from ._registry import register_model
@ -498,7 +499,10 @@ class EfficientFormerV2Stage(nn.Module):
def forward(self, x): def forward(self, x):
x = self.downsample(x) x = self.downsample(x)
x = self.blocks(x) if self.grad_checkpointing:
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x return x
@ -508,6 +512,7 @@ class EfficientFormerV2(nn.Module):
depths, depths,
in_chans=3, in_chans=3,
img_size=224, img_size=224,
global_pool='avg',
embed_dims=None, embed_dims=None,
downsamples=None, downsamples=None,
mlp_ratios=4, mlp_ratios=4,
@ -522,7 +527,9 @@ class EfficientFormerV2(nn.Module):
distillation=True, distillation=True,
): ):
super().__init__() super().__init__()
assert global_pool in ('avg', '')
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = global_pool
self.feature_info = [] self.feature_info = []
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
@ -583,11 +590,49 @@ class EfficientFormerV2(nn.Module):
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x): @torch.jit.ignore
def no_weight_decay(self):
return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem', # stem and embed
blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def forward_features(self, x):
x = self.stem(x) x = self.stem(x)
x = self.stages(x) x = self.stages(x)
x = self.norm(x) x = self.norm(x)
x = x.mean(dim=(2, 3)) return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=(2, 3))
if pre_logits:
return x
x, x_dist = self.head(x), self.head_dist(x) x, x_dist = self.head(x), self.head_dist(x)
if self.distilled_training and self.training and not torch.jit.is_scripting(): if self.distilled_training and self.training and not torch.jit.is_scripting():
# only return separate classification predictions when training in distilled mode # only return separate classification predictions when training in distilled mode
@ -596,6 +641,11 @@ class EfficientFormerV2(nn.Module):
# during standard train/finetune, inference average the classifier predictions # during standard train/finetune, inference average the classifier predictions
return (x + x_dist) / 2 return (x + x_dist) / 2
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
return { return {

Loading…
Cancel
Save