diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 5fb5c7c7..42943fab 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -476,7 +476,7 @@ class VisionTransformer(nn.Module): def _init_weights_original(m: nn.Module, n: str = ''): - if isinstance(m, nn.Linear): + if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0)