|
|
|
@ -21,7 +21,7 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
|
|
|
'crop_pct': .9, 'interpolation': 'bicubic',
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
|
|
|
|
'first_conv': 'stem.convs.0', 'classifier': 'prediction.4',
|
|
|
|
|
'min_input_size': (3, 224, 224),
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
@ -123,10 +123,11 @@ class DeepGCN(torch.nn.Module):
|
|
|
|
|
stochastic = opt.use_stochastic
|
|
|
|
|
conv = opt.conv
|
|
|
|
|
drop_path = opt.drop_path
|
|
|
|
|
channels = opt.channels
|
|
|
|
|
self.num_features = channels[-1] # num_features for consistency with other models
|
|
|
|
|
|
|
|
|
|
blocks = opt.blocks
|
|
|
|
|
self.n_blocks = sum(blocks)
|
|
|
|
|
channels = opt.channels
|
|
|
|
|
reduce_ratios = [4, 2, 1, 1]
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)] # stochastic depth decay
|
|
|
|
|
num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)] # number of knn's k
|
|
|
|
@ -152,11 +153,14 @@ class DeepGCN(torch.nn.Module):
|
|
|
|
|
idx += 1
|
|
|
|
|
self.backbone = Seq(*self.backbone)
|
|
|
|
|
|
|
|
|
|
self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),
|
|
|
|
|
if num_classes > 0:
|
|
|
|
|
self.prediction = Seq(nn.Conv2d(self.num_features, 1024, 1, bias=True),
|
|
|
|
|
nn.BatchNorm2d(1024),
|
|
|
|
|
act_layer(),
|
|
|
|
|
nn.Dropout(opt.dropout),
|
|
|
|
|
nn.Conv2d(1024, num_classes, 1, bias=True))
|
|
|
|
|
else:
|
|
|
|
|
self.prediction = nn.Identity()
|
|
|
|
|
self.model_init()
|
|
|
|
|
|
|
|
|
|
def model_init(self):
|
|
|
|
@ -174,13 +178,30 @@ class DeepGCN(torch.nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
return F.interpolate(pos_embed, size=(H, W), mode="bicubic")
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs):
|
|
|
|
|
x = self.stem(inputs)
|
|
|
|
|
def reset_classifier(self, num_classes: int, global_pool=None):
    """Replace the classification head stored in ``self.prediction``.

    Args:
        num_classes: Number of output classes. A value <= 0 removes the
            head and installs ``nn.Identity()`` instead.
        global_pool: Optional pooling mode; must be one of ``''``, ``'avg'``
            or ``'token'``. When given it is stored on ``self.global_pool``.
            When ``None`` the attribute is left untouched.

    The new head has the same layout as before: 1x1 conv to 1024 channels,
    batch norm, activation, dropout, 1x1 conv to ``num_classes``.

    The previous version referenced ``act_layer`` and ``opt.dropout``; in the
    visible code these appear only as names used inside ``__init__``, so they
    would be unresolved here and raise ``NameError`` for ``num_classes > 0``.
    The activation and dropout rate are now taken from the head being
    replaced and remembered on ``self._head_cfg`` so they survive a reset
    to ``num_classes=0``.
    """
    import copy  # local import: the file's import block is outside this view

    self.num_classes = num_classes
    if global_pool is not None:
        assert global_pool in ('', 'avg', 'token')
        self.global_pool = global_pool

    # Start from whatever an earlier reset remembered.
    # NOTE(review): the GELU / 0.0 fallback is only reached when the model
    # never had a full head; GELU is assumed from the "*_gelu" variant
    # names - confirm against __init__.
    activation, drop_rate = getattr(self, '_head_cfg', (nn.GELU(), 0.0))

    # Inspect the current head; nn.Identity has no children, so the loop is
    # a no-op when the head had already been removed.
    current_head = getattr(self, 'prediction', None)
    head_layers = current_head.children() if isinstance(current_head, nn.Module) else ()
    for layer in head_layers:
        if isinstance(layer, nn.Dropout):
            drop_rate = layer.p
        elif not isinstance(layer, (nn.Conv2d, nn.BatchNorm2d)):
            activation = layer

    # Kept in a tuple so the remembered activation is not registered as an
    # extra submodule of the model.
    self._head_cfg = (activation, drop_rate)

    if num_classes > 0:
        self.prediction = Seq(
            nn.Conv2d(self.num_features, 1024, 1, bias=True),
            nn.BatchNorm2d(1024),
            copy.deepcopy(activation),
            nn.Dropout(drop_rate),
            nn.Conv2d(1024, num_classes, 1, bias=True),
        )
    else:
        self.prediction = nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
    """Compute backbone features for a batch of inputs.

    The input goes through ``self.stem``; the stem output must be 4-D, as
    its shape is unpacked into four values. A position embedding obtained
    from ``self._get_pos_embed`` for the stem output's last two dimensions
    is added, and the result is passed through every module of
    ``self.backbone`` in order.

    Returns:
        The output of the last backbone module (no pooling, no head).
    """
    stem_out = self.stem(x)
    # Unpack all four dims so a non-4-D stem output still fails loudly.
    _batch, _chans, height, width = stem_out.shape
    feats = stem_out + self._get_pos_embed(self.pos_embed, height, width)
    for stage in self.backbone:
        feats = stage(feats)
    return feats
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
    """Run the full model: features, global average pooling, then the head.

    The feature map from ``self.forward_features`` is average-pooled to a
    1x1 spatial size, passed through ``self.prediction``, and the two
    trailing dimensions of the result are squeezed away.
    """
    features = self.forward_features(x)
    pooled = F.adaptive_avg_pool2d(features, 1)
    logits = self.prediction(pooled)
    # Two separate squeezes: only the trailing dims are dropped, so a batch
    # of size 1 keeps its leading dimension.
    logits = logits.squeeze(-1)
    return logits.squeeze(-1)
|
|
|
|
|
|
|
|
|
@ -200,7 +221,7 @@ def _create_pvig(variant, opt, pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def pvig_ti_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
def pvig_ti_224_gelu(pretrained=False, num_classes=1000, **kwargs):
|
|
|
|
|
class OptInit:
|
|
|
|
|
def __init__(self, drop_path_rate=0.0, **kwargs):
|
|
|
|
|
self.k = 9 # neighbor num (default:9)
|
|
|
|
@ -216,13 +237,13 @@ def pvig_ti_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
self.channels = [48, 96, 240, 384] # number of channels of deep features
|
|
|
|
|
|
|
|
|
|
opt = OptInit(**kwargs)
|
|
|
|
|
model = _create_pvig('pvig_ti_224_gelu', opt, pretrained)
|
|
|
|
|
model = _create_pvig('pvig_ti_224_gelu', opt, pretrained, num_classes=num_classes)
|
|
|
|
|
model.default_cfg = default_cfgs['pvig_ti_224_gelu']
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def pvig_s_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
def pvig_s_224_gelu(pretrained=False, num_classes=1000, **kwargs):
|
|
|
|
|
class OptInit:
|
|
|
|
|
def __init__(self, drop_path_rate=0.0, **kwargs):
|
|
|
|
|
self.k = 9 # neighbor num (default:9)
|
|
|
|
@ -238,13 +259,13 @@ def pvig_s_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
self.channels = [80, 160, 400, 640] # number of channels of deep features
|
|
|
|
|
|
|
|
|
|
opt = OptInit(**kwargs)
|
|
|
|
|
model = _create_pvig('pvig_s_224_gelu', opt, pretrained)
|
|
|
|
|
model = _create_pvig('pvig_s_224_gelu', opt, pretrained, num_classes=num_classes)
|
|
|
|
|
model.default_cfg = default_cfgs['pvig_s_224_gelu']
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def pvig_m_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
def pvig_m_224_gelu(pretrained=False, num_classes=1000, **kwargs):
|
|
|
|
|
class OptInit:
|
|
|
|
|
def __init__(self, drop_path_rate=0.0, **kwargs):
|
|
|
|
|
self.k = 9 # neighbor num (default:9)
|
|
|
|
@ -260,13 +281,13 @@ def pvig_m_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
self.channels = [96, 192, 384, 768] # number of channels of deep features
|
|
|
|
|
|
|
|
|
|
opt = OptInit(**kwargs)
|
|
|
|
|
model = _create_pvig('pvig_m_224_gelu', opt, pretrained)
|
|
|
|
|
model = _create_pvig('pvig_m_224_gelu', opt, pretrained, num_classes=num_classes)
|
|
|
|
|
model.default_cfg = default_cfgs['pvig_m_224_gelu']
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def pvig_b_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
def pvig_b_224_gelu(pretrained=False, num_classes=1000, **kwargs):
|
|
|
|
|
class OptInit:
|
|
|
|
|
def __init__(self, drop_path_rate=0.0, **kwargs):
|
|
|
|
|
self.k = 9 # neighbor num (default:9)
|
|
|
|
@ -282,6 +303,6 @@ def pvig_b_224_gelu(pretrained=False, **kwargs):
|
|
|
|
|
self.channels = [128, 256, 512, 1024] # number of channels of deep features
|
|
|
|
|
|
|
|
|
|
opt = OptInit(**kwargs)
|
|
|
|
|
model = _create_pvig('pvig_b_224_gelu', opt, pretrained)
|
|
|
|
|
model = _create_pvig('pvig_b_224_gelu', opt, pretrained, num_classes=num_classes)
|
|
|
|
|
model.default_cfg = default_cfgs['pvig_b_224_gelu']
|
|
|
|
|
return model
|
|
|
|
|