@ -476,7 +476,7 @@ class VisionTransformer(nn.Module):
def _init_weights_original(m: nn.Module, n: str = ''):
if isinstance(m, nn.Linear):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)