|
|
|
@ -13,24 +13,22 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|
|
|
|
# This source code is licensed under the MIT license
|
|
|
|
|
|
|
|
|
|
import itertools
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
|
|
from .features import FeatureInfo
|
|
|
|
|
from .fx_features import register_notrace_function
|
|
|
|
|
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
|
|
from .features import FeatureInfo
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
|
from .pretrained import generate_default_cfgs
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
@ -109,6 +107,8 @@ class PatchEmbed(nn.Module):
|
|
|
|
|
H, W = size
|
|
|
|
|
in_shape = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# norm_after variable used as a workaround to original len(x.shape) == 3
|
|
|
|
|
if self.norm_after == False:
|
|
|
|
|
B, HW, C = in_shape
|
|
|
|
|
x = self.norm(x)
|
|
|
|
@ -204,7 +204,7 @@ def window_partition(x : Tensor, window_size: int):
|
|
|
|
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
|
|
|
|
return windows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_notrace_function # reason: int argument is a Proxy
|
|
|
|
|
def window_reverse(windows: Tensor, window_size: int, H: int, W: int) -> Tensor:
    """Reassemble non-overlapping square windows into a (B, H, W, C) map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Side length of each square window
        H (int): Height of the reassembled map; expected to be a multiple
            of ``window_size``
        W (int): Width of the reassembled map; expected to be a multiple
            of ``window_size``

    Returns:
        x: (B, H, W, C)
    """
    # Read the channel count from the input so that the batch dimension can be
    # the single inferred (-1) axis. This avoids deriving B through float
    # division and also keeps an empty batch of windows well-defined.
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
    # Interleave (row-block, in-window row) and (col-block, in-window col)
    # before flattening them back into the H and W axes.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|