Make all models FX traceable

pull/800/head
Alexander Soare 3 years ago
parent cf4561ca72
commit e051dce354

@ -22,6 +22,7 @@ import torch.nn.functional as F
from .helpers import to_2tuple, make_divisible from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_
from timm.models.fx_helpers import fx_and
def rel_logits_1d(q, rel_k, permute_mask: List[int]): def rel_logits_1d(q, rel_k, permute_mask: List[int]):

@ -24,7 +24,6 @@ import torch.nn.functional as F
from .helpers import make_divisible from .helpers import make_divisible
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_
from timm.models.fx_helpers import
def rel_logits_1d(q, rel_k, permute_mask: List[int]): def rel_logits_1d(q, rel_k, permute_mask: List[int]):

@ -10,6 +10,7 @@ from torch.nn import functional as F
from .conv_bn_act import ConvBnAct from .conv_bn_act import ConvBnAct
from .helpers import make_divisible from .helpers import make_divisible
from timm.models.fx_helpers import fx_and
class NonLocalAttn(nn.Module): class NonLocalAttn(nn.Module):

@ -12,6 +12,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg from timm.models.helpers import build_model_with_cfg
from timm.models.fx_helpers import fx_and
from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model from timm.models.registry import register_model

Loading…
Cancel
Save