wip - attempting to rebase

pull/800/head
Alexander Soare 3 years ago
parent 02c3a75a45
commit 0149ec30d7

@ -24,7 +24,7 @@ import torch.nn.functional as F
from torch.fx.graph_module import _copy_attr from torch.fx.graph_module import _copy_attr
from .features import _get_feature_info from .features import _get_feature_info
from .fx_helpers import fx_and, fx_float_to_int from .fx_helpers import fx_float_to_int
# Layers we went to treat as leaf modules for FeatureGraphNet # Layers we went to treat as leaf modules for FeatureGraphNet
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame
@ -55,7 +55,7 @@ def register_leaf_module(module: nn.Module):
# These functions will not be traced through # These functions will not be traced through
_autowrap_functions=(fx_float_to_int, fx_and) _autowrap_functions=(fx_float_to_int,)
class TimmTracer(fx.Tracer): class TimmTracer(fx.Tracer):

@ -1,14 +1,4 @@
def fx_and(a: bool, b: bool) -> bool:
"""
Symbolic tracing helper to substitute for normal usage of `* and *` within `torch._assert`.
Hint: Symbolic tracing does not support control flow but since an `assert` is either a dead-end or not, this hack
is okay.
"""
return (a and b)
def fx_float_to_int(x: float) -> int: def fx_float_to_int(x: float) -> int:
""" """
Symbolic tracing helper to substitute for inbuilt `int`. Symbolic tracing helper to substitute for inbuilt `int`.

@ -22,7 +22,6 @@ import torch.nn.functional as F
from .helpers import to_2tuple, make_divisible from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_
from timm.models.fx_helpers import fx_and
def rel_logits_1d(q, rel_k, permute_mask: List[int]): def rel_logits_1d(q, rel_k, permute_mask: List[int]):

@ -24,7 +24,7 @@ import torch.nn.functional as F
from .helpers import make_divisible from .helpers import make_divisible
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_
from timm.models.fx_helpers import fx_and from timm.models.fx_helpers import
def rel_logits_1d(q, rel_k, permute_mask: List[int]): def rel_logits_1d(q, rel_k, permute_mask: List[int]):

@ -10,7 +10,6 @@ from torch.nn import functional as F
from .conv_bn_act import ConvBnAct from .conv_bn_act import ConvBnAct
from .helpers import make_divisible from .helpers import make_divisible
from timm.models.fx_helpers import fx_and
class NonLocalAttn(nn.Module): class NonLocalAttn(nn.Module):
@ -96,7 +95,8 @@ class BilinearAttnTransform(nn.Module):
return x return x
def forward(self, x): def forward(self, x):
torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '') torch._assert(x.shape[-1] % self.block_size == 0, '')
torch._assert(x.shape[-2] % self.block_size == 0, '')
B, C, H, W = x.shape B, C, H, W = x.shape
out = self.conv1(x) out = self.conv1(x)
rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) rp = F.adaptive_max_pool2d(out, (self.block_size, 1))

@ -9,11 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn from torch import nn as nn
from .helpers import to_2tuple from .helpers import to_2tuple
<<<<<<< HEAD
from .trace_utils import _assert from .trace_utils import _assert
=======
from timm.models.fx_helpers import fx_and
>>>>>>> Make all models FX traceable
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):

@ -12,7 +12,6 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg from timm.models.helpers import build_model_with_cfg
from timm.models.fx_helpers import fx_and
from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model from timm.models.registry import register_model
@ -138,7 +137,9 @@ class PixelEmbed(nn.Module):
def forward(self, x, pixel_pos): def forward(self, x, pixel_pos):
B, C, H, W = x.shape B, C, H, W = x.shape
torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]), torch._assert(H == self.img_size[0],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
torch._assert(W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x) x = self.proj(x)
x = self.unfold(x) x = self.unfold(x)

Loading…
Cancel
Save