|
|
|
@ -12,7 +12,6 @@ import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from timm.models.helpers import build_model_with_cfg
|
|
|
|
|
from timm.models.fx_helpers import fx_and
|
|
|
|
|
from timm.models.layers import Mlp, DropPath, trunc_normal_
|
|
|
|
|
from timm.models.layers.helpers import to_2tuple
|
|
|
|
|
from timm.models.registry import register_model
|
|
|
|
@ -138,7 +137,9 @@ class PixelEmbed(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward(self, x, pixel_pos):
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]),
|
|
|
|
|
torch._assert(H == self.img_size[0],
|
|
|
|
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
|
|
|
|
torch._assert(W == self.img_size[1],
|
|
|
|
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
|
|
|
|
x = self.proj(x)
|
|
|
|
|
x = self.unfold(x)
|
|
|
|
|