Fix some model support functions

pull/1415/head
Ross Wightman 2 years ago
parent f332fc2db7
commit ca52108c2b

timm/models/efficientformer.py

@@ -449,13 +449,12 @@ class EfficientFormer(nn.Module):
     def get_classifier(self):
         return self.head, self.head_dist
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+    def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        if self.dist:
-            self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     @torch.jit.ignore
     def set_distilled_training(self, enable=True):
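Note: this hunk drops the unused `distillation` argument and rebuilds `head_dist` unconditionally, so both heads stay in sync after a reset. A minimal sketch of the resulting behavior, assuming timm's usual create_model flow (the model name is just an example):

    import timm
    import torch.nn as nn

    # sketch against this commit; 'efficientformer_l1' is one of the registered variants
    model = timm.create_model('efficientformer_l1', pretrained=False)
    model.reset_classifier(10)

    head, head_dist = model.get_classifier()  # this model returns both heads
    assert isinstance(head, nn.Linear) and head.out_features == 10
    # head_dist is now rebuilt alongside head instead of being gated on self.dist
    assert isinstance(head_dist, nn.Linear) and head_dist.out_features == 10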

timm/models/gcvit.py

@@ -427,6 +427,7 @@ class GlobalContextVit(nn.Module):
         feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
         self.global_pool = global_pool
         self.num_classes = num_classes
+        self.drop_rate = drop_rate
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))
@@ -491,7 +492,7 @@ class GlobalContextVit(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=(r'^stages\.(\d+)', None)
+            blocks=r'^stages\.(\d+)'
         )
         return matcher
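For context: `group_matcher` returns a dict mapping group names to regex patterns over parameter names; a value may be a plain regex, a `(regex, ordinal)` tuple, or a list of such tuples. This hunk (and the PyramidVisionTransformerV2 one below) simplifies the `blocks` entry to a plain regex. A rough stand-in for how a consumer might bucket parameters with such a dict, ignoring ordinals; this is illustrative, not timm's actual grouping code:

    import re

    def _patterns(value):
        # a matcher value may be a regex string, a (regex, ordinal) tuple,
        # or a list of such entries (all three forms appear in this diff)
        if isinstance(value, str):
            return [value]
        if isinstance(value, tuple):
            return [value[0]]
        return [v[0] if isinstance(v, tuple) else v for v in value]

    def bucket_parameters(model, matcher):
        # assign each parameter name to the first matching group, else 'other'
        groups = {name: [] for name in matcher}
        groups['other'] = []
        for pname, _ in model.named_parameters():
            for gname, value in matcher.items():
                if any(re.match(p, pname) for p in _patterns(value)):
                    groups[gname].append(pname)
                    break
            else:
                groups['other'].append(pname)
        return groups

    # bucket_parameters(model, model.group_matcher()) puts 'stem.*' params under
    # 'stem' and 'stages.N.*' params under 'blocks'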
@@ -500,6 +501,16 @@ class GlobalContextVit(nn.Module):
         for s in self.stages:
             s.grad_checkpointing = enable
 
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is None:
+            global_pool = self.head.global_pool.pool_type
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
         x = self.stages(x)
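The `drop_rate` stored in the constructor hunk above is what this new `reset_classifier` reads when rebuilding the `ClassifierHead`; the pool type is carried over from the old head unless one is passed in. A short usage sketch, assuming timm's gcvit registry names; the checks reflect this commit's code, not a stable contract:

    import timm
    import torch.nn as nn

    model = timm.create_model('gcvit_tiny', pretrained=False, drop_rate=0.1)
    model.reset_classifier(5)

    assert isinstance(model.get_classifier(), nn.Linear)  # head.fc, not the whole head
    print(model.head.global_pool.pool_type)               # preserved across the reset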

timm/models/mvitv2.py

@@ -850,7 +850,7 @@ class MultiScaleVit(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^stem',  # stem and embed
+            stem=r'^patch_embed',  # stem and embed
             blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
         )
         return matcher
@@ -862,7 +862,7 @@ class MultiScaleVit(nn.Module):
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
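As with GlobalContextVit above, MultiScaleVit's `head` is a composite module wrapping the final linear layer, so `get_classifier` now returns the inner `self.head.fc` to match the convention elsewhere in timm. A quick sketch of what callers can rely on after this change ('mvitv2_tiny' is just an example name):

    import timm
    import torch.nn as nn

    model = timm.create_model('mvitv2_tiny', pretrained=False)
    fc = model.get_classifier()
    assert isinstance(fc, nn.Linear)  # previously the whole head module was returned
    print(fc.in_features, fc.out_features)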

timm/models/pvt_v2.py

@@ -351,7 +351,7 @@ class PyramidVisionTransformerV2(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)'
         )
         return matcher
