diff --git a/timm/models/davit.py b/timm/models/davit.py index 46a97648..59850ff1 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -13,24 +13,22 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # This source code is licensed under the MIT license import itertools - from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List - - +from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor +import torch.utils.checkpoint as checkpoint + +from .features import FeatureInfo +from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp -from .features import FeatureInfo -from collections import OrderedDict -import torch.utils.checkpoint as checkpoint from .pretrained import generate_default_cfgs from .registry import register_model - +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['DaViT'] @@ -109,6 +107,8 @@ class PatchEmbed(nn.Module): H, W = size in_shape = x.shape + + # norm_after variable used as a workaround to original len(x.shape) == 3 if self.norm_after == False: B, HW, C = in_shape x = self.norm(x) @@ -204,7 +204,7 @@ def window_partition(x : Tensor, window_size: int): windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows - +@register_notrace_function # reason: int argument is a Proxy def window_reverse(windows : Tensor, window_size: int, H: int, W: int): """ Args: @@ -216,19 +216,9 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int): x: (B, H, W, C) """ - #B = torch.floor(torch.tensor(windows.size(dim=0)) / (H * W / window_size / window_size)).int().item() - x = windows.view( - torch.floor(torch.tensor(windows.size(dim=0)) / (H * W / window_size / window_size)).int().item(), - H // window_size, - W // window_size, - window_size, - window_size, - -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view( - torch.floor(torch.tensor(windows.size(dim=0)) / (H * W / window_size / window_size)).int().item(), - H, - W, - -1) + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x