@@ -25,6 +25,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
@@ -498,6 +499,9 @@ class EfficientFormerV2Stage(nn.Module):
    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing:
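            # checkpoint_seq() recomputes the block activations in the backward
            # pass instead of caching them, trading compute for activation memory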
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
@@ -508,6 +512,7 @@ class EfficientFormerV2(nn.Module):
            depths,
            in_chans=3,
            img_size=224,
            global_pool='avg',
            embed_dims=None,
            downsamples=None,
            mlp_ratios=4,
@@ -522,7 +527,9 @@ class EfficientFormerV2(nn.Module):
            distillation=True,
    ):
        super().__init__()
        assert global_pool in ('avg', '')
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.feature_info = []
        img_size = to_2tuple(img_size)
        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
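        # bind eps once so every norm layer created below shares the same value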
@@ -583,11 +590,49 @@ class EfficientFormerV2(nn.Module):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
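        # keep the learned attention bias tables out of weight decay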
        return {k for k, _ in self.named_parameters() if 'attention_biases' in k}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
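        # regex-to-group mapping used by timm's optimizer and freeze helpers;
        # (99999,) sorts the final norm after all of the stage groups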
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
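        # usage sketch: call model.set_grad_checkpointing(True) before training
        # to trade recompute for memory; each stage checks the flag in forward()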
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=None):
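        # rebuilds both the main and the distillation classifier heads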
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
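        # when enabled, forward_head() returns separate (cls, dist) logits
        # during training instead of their average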
        self.distilled_training = enable

    def forward_features(self, x):
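        # returns the final 4D feature map; pooling is left to forward_head()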
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            x = x.mean(dim=(2, 3))
        if pre_logits:
            return x
        x, x_dist = self.head(x), self.head_dist(x)
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
@@ -596,6 +641,11 @@ class EfficientFormerV2(nn.Module):
            # during standard train/finetune and at inference, average the two classifier predictions
            return (x + x_dist) / 2

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
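
# usage sketch ('efficientformerv2_s0' assumes the registry names added
# further down in this file):
#   model = timm.create_model('efficientformerv2_s0', pretrained=False)
#   model.set_distilled_training(True)   # (cls, dist) logit pairs while training
#   model.set_grad_checkpointing(True)   # recompute stage blocks to save memory
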
def _cfg(url='', **kwargs):
    return {