diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 42997fb8..b0b759c7 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -6,7 +6,7 @@ Based on the impl in https://github.com/google-research/vision_transformer Hacked together by / Copyright 2020 Ross Wightman """ - +import torch from torch import nn as nn from .helpers import to_2tuple @@ -30,8 +30,8 @@ class PatchEmbed(nn.Module): def forward(self, x): B, C, H, W = x.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model {self.img_size[0]}.") + torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}.") x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC