|
|
@ -6,7 +6,7 @@ Based on the impl in https://github.com/google-research/vision_transformer
|
|
|
|
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn as nn
|
|
|
|
from torch import nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
from .helpers import to_2tuple
|
|
|
|
from .helpers import to_2tuple
|
|
|
@ -30,8 +30,8 @@ class PatchEmbed(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
assert H == self.img_size[0] and W == self.img_size[1], \
|
|
|
|
torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model {self.img_size[0]}.")
|
|
|
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
|
|
|
torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}.")
|
|
|
|
x = self.proj(x)
|
|
|
|
x = self.proj(x)
|
|
|
|
if self.flatten:
|
|
|
|
if self.flatten:
|
|
|
|
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
|
|
|
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
|
|
|