Should have included Conv2d layers in original weight init. Lets see what the impact is...

pull/450/head
Ross Wightman 4 years ago
parent 4de57ccf01
commit cbcb76d72c

@ -476,7 +476,7 @@ class VisionTransformer(nn.Module):
def _init_weights_original(m: nn.Module, n: str = ''): def _init_weights_original(m: nn.Module, n: str = ''):
if isinstance(m, nn.Linear): if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02) trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None: if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)

Loading…
Cancel
Save