Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 1fa1e01646
commit ab13e12803

@ -13,24 +13,22 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.utils.checkpoint as checkpoint
from .features import FeatureInfo
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
from .features import FeatureInfo
from collections import OrderedDict
import torch.utils.checkpoint as checkpoint
from .pretrained import generate_default_cfgs
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DaViT']
@ -109,6 +107,8 @@ class PatchEmbed(nn.Module):
H, W = size
in_shape = x.shape
# norm_after variable used as a workaround to original len(x.shape) == 3
if self.norm_after == False:
B, HW, C = in_shape
x = self.norm(x)
@ -204,7 +204,7 @@ def window_partition(x : Tensor, window_size: int):
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
"""
Args:
@ -216,19 +216,9 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
x: (B, H, W, C)
"""
#B = torch.floor(torch.tensor(windows.size(dim=0)) / (H * W / window_size / window_size)).int().item()
x = windows.view(
torch.floor(torch.tensor(windows.size(dim=0)) / (H * W / window_size / window_size)).int().item(),
H // window_size,
W // window_size,
window_size,
window_size,
-1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
torch.floor(torch.tensor(windows.size(dim=0)) / (H * W / window_size / window_size)).int().item(),
H,
W,
-1)
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x

Loading…
Cancel
Save