diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py
index 2da323cf..814b6957 100644
--- a/timm/models/efficientformer.py
+++ b/timm/models/efficientformer.py
@@ -449,13 +449,12 @@ class EfficientFormer(nn.Module):
     def get_classifier(self):
         return self.head, self.head_dist
 
-    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+    def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        if self.dist:
-            self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     @torch.jit.ignore
     def set_distilled_training(self, enable=True):
diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py
index c134b7c2..e7eccea8 100644
--- a/timm/models/gcvit.py
+++ b/timm/models/gcvit.py
@@ -427,6 +427,7 @@ class GlobalContextVit(nn.Module):
         feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
         self.global_pool = global_pool
         self.num_classes = num_classes
+        self.drop_rate = drop_rate
 
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))
@@ -491,7 +492,7 @@ class GlobalContextVit(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=(r'^stages\.(\d+)', None)
+            blocks=r'^stages\.(\d+)'
         )
         return matcher
 
@@ -500,6 +501,16 @@
         for s in self.stages:
             s.grad_checkpointing = enable
 
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is None:
+            global_pool = self.head.global_pool.pool_type
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
         x = self.stages(x)
diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py
index 4c7cd044..fc29f113 100644
--- a/timm/models/mvitv2.py
+++ b/timm/models/mvitv2.py
@@ -850,7 +850,7 @@ class MultiScaleVit(nn.Module):
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
-            stem=r'^stem',  # stem and embed
+            stem=r'^patch_embed',  # stem and embed
             blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
         )
         return matcher
@@ -862,7 +862,7 @@ class MultiScaleVit(nn.Module):
 
     @torch.jit.ignore
     def get_classifier(self):
-        return self.head
+        return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index 1f698fbc..ce4cbf56 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -351,7 +351,7 @@ class PyramidVisionTransformerV2(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^patch_embed',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)'
         )
         return matcher
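
A minimal sketch of the classifier API these changes converge on, assuming the stock timm create_model factory; the model name, input size, and 10-class target are illustrative, not part of the patch:

    import timm
    import torch

    model = timm.create_model('gcvit_tiny', pretrained=False)

    # After the patch, get_classifier() returns the inner Linear (head.fc)
    # rather than the wrapping ClassifierHead module.
    fc = model.get_classifier()

    # reset_classifier() rebuilds the head for a new class count; when
    # global_pool is None the existing pooling type is preserved.
    model.reset_classifier(num_classes=10)
    out = model(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 10)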