Fix formatting of last commit

more_datasets
Ross Wightman 3 years ago
parent 3478f1d7f1
commit b745d30a3e

@ -30,8 +30,8 @@ class PatchEmbed(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model {self.img_size[0]}.") torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}.") torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = x.flatten(2).transpose(1, 2) # BCHW -> BNC

Loading…
Cancel
Save