|
|
|
@ -13,6 +13,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .helpers import load_pretrained, build_model_with_cfg
|
|
|
|
|
from .layers import DropPath, Grapher
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .fx_features import register_notrace_function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
@ -172,6 +173,7 @@ class DeepGCN(torch.nn.Module):
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
m.bias.requires_grad = True
|
|
|
|
|
|
|
|
|
|
@register_notrace_function # reason: int argument is a Proxy
|
|
|
|
|
def _get_pos_embed(self, pos_embed, H, W):
|
|
|
|
|
if pos_embed is None or (H == pos_embed.size(-2) and W == pos_embed.size(-1)):
|
|
|
|
|
return pos_embed
|
|
|
|
|