|
|
@ -21,10 +21,10 @@ import torch.nn as nn
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from .fx_features import register_autowrap_function
|
|
|
|
from .fx_features import register_notrace_function
|
|
|
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
|
|
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
|
|
|
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
|
|
|
|
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
|
|
|
|
from .layers.trace_utils import _assert
|
|
|
|
from .layers import _assert
|
|
|
|
from .registry import register_model
|
|
|
|
from .registry import register_model
|
|
|
|
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
|
|
|
|
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
|
|
|
|
|
|
|
|
|
|
|
@ -103,7 +103,7 @@ def window_partition(x, window_size: int):
|
|
|
|
return windows
|
|
|
|
return windows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_autowrap_function # reason: int argument is a Proxy
|
|
|
|
@register_notrace_function # reason: int argument is a Proxy
|
|
|
|
def window_reverse(windows, window_size: int, H: int, W: int):
|
|
|
|
def window_reverse(windows, window_size: int, H: int, W: int):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|