|
|
|
@ -546,7 +546,7 @@ class MetaFormer(nn.Module):
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
def set_grad_checkpointing(self, enable=True):
|
|
|
|
|
print("not implemented")
|
|
|
|
|
self.grad_checkpointing = enable
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
def get_classifier(self):
|
|
|
|
@ -574,7 +574,10 @@ class MetaFormer(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x = self.stem(x)
|
|
|
|
|
x = self.stages(x)
|
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
|
x = checkpoint_seq(self.stages, x)
|
|
|
|
|
else:
|
|
|
|
|
x = self.stages(x)
|
|
|
|
|
x = self.norm_pre(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
@ -583,6 +586,8 @@ class MetaFormer(nn.Module):
|
|
|
|
|
x = self.forward_head(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
# FIXME convert to group matcher
|
|
|
|
|
# this works but it's long and breaks backwards compatability with weights from the poolformer-only impl
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
import re
|
|
|
|
|
out_dict = {}
|
|
|
|
@ -817,6 +822,7 @@ default_cfgs = generate_default_cfgs({
|
|
|
|
|
classifier='head.fc.fc2', num_classes=21841),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
# FIXME fully merge poolformerv1, rename to poolformer to succeed poolformer.py
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def poolformerv1_s12(pretrained=False, **kwargs):
|
|
|
|
|