update ViG models

pull/1578/head
iamhankai 3 years ago
parent cb725a85a2
commit f080f1b8c9

@ -28,7 +28,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'pvig_*',
]
NUM_NON_STD = len(NON_STD_FILTERS)

@ -114,7 +114,6 @@ class DeepGCN(torch.nn.Module):
super(DeepGCN, self).__init__()
self.num_classes = num_classes
self.in_chans = in_chans
print(opt)
k = opt.k
act_layer = nn.GELU
norm = opt.norm
@ -168,9 +167,16 @@ class DeepGCN(torch.nn.Module):
m.bias.data.zero_()
m.bias.requires_grad = True
def _get_pos_embed(self, pos_embed, H, W):
if pos_embed is None or (H == pos_embed.size(-2) and W == pos_embed.size(-1)):
return pos_embed
else:
return F.interpolate(pos_embed, size=(H, W), mode="bicubic")
def forward(self, inputs):
x = self.stem(inputs) + self.pos_embed
x = self.stem(inputs)
B, C, H, W = x.shape
x = x + self._get_pos_embed(self.pos_embed, H, W)
for i in range(len(self.backbone)):
x = self.backbone[i](x)

Loading…
Cancel
Save