From f080f1b8c9c1de58477e4b94939989d2a68663d6 Mon Sep 17 00:00:00 2001 From: iamhankai Date: Tue, 6 Dec 2022 12:18:50 +0800 Subject: [PATCH] update ViG models --- tests/test_models.py | 2 +- timm/models/vision_gnn.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index d007d65a..1d619c62 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -28,7 +28,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', + 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'pvig_*', ] NUM_NON_STD = len(NON_STD_FILTERS) diff --git a/timm/models/vision_gnn.py b/timm/models/vision_gnn.py index 5e48ef9d..4fe4a8f6 100755 --- a/timm/models/vision_gnn.py +++ b/timm/models/vision_gnn.py @@ -114,7 +114,6 @@ class DeepGCN(torch.nn.Module): super(DeepGCN, self).__init__() self.num_classes = num_classes self.in_chans = in_chans - print(opt) k = opt.k act_layer = nn.GELU norm = opt.norm @@ -168,9 +167,16 @@ class DeepGCN(torch.nn.Module): m.bias.data.zero_() m.bias.requires_grad = True + def _get_pos_embed(self, pos_embed, H, W): + if pos_embed is None or (H == pos_embed.size(-2) and W == pos_embed.size(-1)): + return pos_embed + else: + return F.interpolate(pos_embed, size=(H, W), mode="bicubic") + def forward(self, inputs): - x = self.stem(inputs) + self.pos_embed + x = self.stem(inputs) B, C, H, W = x.shape + x = x + self._get_pos_embed(self.pos_embed, H, W) for i in range(len(self.backbone)): x = self.backbone[i](x)