@@ -251,7 +251,7 @@ class VisionTransformer(nn.Module):
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}

-    def forward_features(self, x):
+    def forward_features(self, x, attn_mask=None):
         B = x.shape[0]
         x = self.patch_embed(x)

@@ -263,6 +263,10 @@ class VisionTransformer(nn.Module):
         for blk in self.blocks:
-            x = blk(x)
+            x = blk(x, attn_mask=attn_mask)

         x = self.norm(x[:, 0])
         return x

+    def forward(self, x, attn_mask=None):
+        x = self.forward_features(x, attn_mask)
+        x = self.head(x)
+        return x
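
For reference, a minimal usage sketch of the new argument (not part of the diff). It assumes a timm-style `VisionTransformer` with default constructor arguments and a boolean mask over the full token sequence (class token plus patch tokens); the exact mask shape, dtype, and semantics expected inside the blocks are assumptions here, not something stated in the patch.

import torch

# Hypothetical example; mask shape/dtype are assumptions, not taken from the patch.
model = VisionTransformer()                  # assumes default constructor args
images = torch.randn(2, 3, 224, 224)         # batch of 2 RGB images

num_tokens = model.pos_embed.shape[1]        # 1 cls token + N patch tokens
attn_mask = torch.ones(2, num_tokens, dtype=torch.bool)
attn_mask[:, -10:] = False                   # e.g. hide the last 10 patch tokens

logits = model(images, attn_mask=attn_mask)  # mask is threaded through every block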