@@ -427,6 +427,7 @@ class GlobalContextVit(nn.Module):
         feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
         self.global_pool = global_pool
         self.num_classes = num_classes
+        self.drop_rate = drop_rate
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))
@@ -491,7 +492,7 @@ class GlobalContextVit(nn.Module):
     def group_matcher(self, coarse=False):
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=(r'^stages\.(\d+)', None)
+            blocks=r'^stages\.(\d+)'
         )
         return matcher
 
@@ -500,6 +501,16 @@ class GlobalContextVit(nn.Module):
         for s in self.stages:
             s.grad_checkpointing = enable
 
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is None:
+            global_pool = self.head.global_pool.pool_type
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
         x = self.stages(x)
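Note: a minimal sketch of how the classifier API added here is typically exercised, assuming a timm build with this patch applied; the model name 'gcvit_tiny' is illustrative and not part of this diff.

    import timm

    # Build a GCViT model (any registered gcvit variant behaves the same way).
    model = timm.create_model('gcvit_tiny', pretrained=False, num_classes=1000)

    # Swap the head for a 10-class task; pooling type and the drop_rate stored
    # by this patch are carried over into the rebuilt ClassifierHead.
    model.reset_classifier(num_classes=10)

    # num_classes=0 makes the head fc an Identity, i.e. pooled feature output.
    model.reset_classifier(0)
    print(model.get_classifier())  # -> Identity()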