Traceability fix for vit models for some experiments

more_datasets
Ross Wightman 3 years ago
parent f658a72e72
commit 3478f1d7f1

@ -6,7 +6,7 @@ Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import torch
from torch import nn as nn from torch import nn as nn
from .helpers import to_2tuple from .helpers import to_2tuple
@ -30,8 +30,8 @@ class PatchEmbed(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model {self.img_size[0]}.")
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}.")
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = x.flatten(2).transpose(1, 2) # BCHW -> BNC

Loading…
Cancel
Save