|
|
@ -476,7 +476,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weights_original(m: nn.Module, n: str = ''):
|
|
|
|
def _init_weights_original(m: nn.Module, n: str = ''):
|
|
|
|
if isinstance(m, nn.Linear):
|
|
|
|
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|