From 8e4ac3549f65eefa6b094cd04876b19ed3ca7506 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 7 Jun 2021 17:14:19 -0700
Subject: [PATCH] All ScaledStdConv and StdConv uses default to using
 F.layernorm so that they work with PyTorch XLA. eps value tweaking is a WIP.

---
 timm/models/layers/std_conv.py           | 52 +++++++++++++-----------
 timm/models/nfnet.py                     |  7 +++-
 timm/models/vision_transformer_hybrid.py |  2 +-
 3 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py
index b0cb1eeb..a1afc653 100644
--- a/timm/models/layers/std_conv.py
+++ b/timm/models/layers/std_conv.py
@@ -19,17 +19,22 @@ class StdConv2d(nn.Conv2d):
     """
     def __init__(
             self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
-            groups=1, bias=False, eps=1e-5):
+            groups=1, bias=False, eps=1e-5, use_layernorm=True):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.eps = eps
+        self.use_layernorm = use_layernorm
 
     def get_weight(self):
-        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        weight = (self.weight - mean) / (std + self.eps)
+        if self.use_layernorm:
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = (self.weight - mean) / (std + self.eps)
         return weight
 
     def forward(self, x):
@@ -45,17 +50,22 @@ class StdConv2dSame(nn.Conv2d):
     """
     def __init__(
             self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
-            groups=1, bias=False, eps=1e-5):
+            groups=1, bias=False, eps=1e-5, use_layernorm=True):
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.same_pad = is_dynamic
         self.eps = eps
+        self.use_layernorm = use_layernorm
 
     def get_weight(self):
-        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        weight = (self.weight - mean) / (std + self.eps)
+        if self.use_layernorm:
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = (self.weight - mean) / (std + self.eps)
         return weight
 
     def forward(self, x):
@@ -76,7 +86,7 @@ class ScaledStdConv2d(nn.Conv2d):
 
     def __init__(
             self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
+            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
@@ -84,16 +94,17 @@ class ScaledStdConv2d(nn.Conv2d):
             groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
         self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
-        self.eps = eps ** 2 if use_layernorm else eps
+        self.eps = eps
         self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel
 
     def get_weight(self):
         if self.use_layernorm:
-            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
         else:
             std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = self.scale * (self.weight - mean) / (std + self.eps)
-        return self.gain * weight
+            weight = (self.weight - mean) / (std + self.eps)
+        return weight.mul_(self.gain * self.scale)
 
     def forward(self, x):
         return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
@@ -110,7 +121,7 @@ class ScaledStdConv2dSame(nn.Conv2d):
 
     def __init__(
             self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
+            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
@@ -118,24 +129,17 @@ class ScaledStdConv2dSame(nn.Conv2d):
         self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
         self.scale = gamma * self.weight[0].numel() ** -0.5
         self.same_pad = is_dynamic
-        self.eps = eps ** 2 if use_layernorm else eps
+        self.eps = eps
         self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel
 
-    # NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
-    # to make much numerical difference (+/- .002 to .004) in top-1 during eval.
-    # def get_weight(self):
-    #     var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-    #     scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
-    #     weight = (self.weight - mean) * scale
-    #     return self.gain * weight
-
     def get_weight(self):
         if self.use_layernorm:
-            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
         else:
             std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = self.scale * (self.weight - mean) / (std + self.eps)
-        return self.gain * weight
+            weight = (self.weight - mean) / (std + self.eps)
+        return weight.mul_(self.gain * self.scale)
 
     def forward(self, x):
         if self.same_pad:
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 593796a5..584495c3 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -166,6 +166,8 @@ class NfCfg:
     extra_conv: bool = False  # extra 3x3 bottleneck convolution for NFNet models
     gamma_in_act: bool = False
     same_padding: bool = False
+    std_conv_eps: float = 1e-5
+    std_conv_ln: bool = True  # use layer-norm impl to normalize in std-conv, works in PyTorch XLA, slightly faster
     skipinit: bool = False  # disabled by default, non-trivial performance impact
     zero_init_fc: bool = False
     act_layer: str = 'silu'
@@ -482,10 +484,11 @@ class NormFreeNet(nn.Module):
         conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
         if cfg.gamma_in_act:
             act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
-            conv_layer = partial(conv_layer, eps=1e-4)  # DM weights better with higher eps
+            conv_layer = partial(conv_layer, eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
         else:
             act_layer = get_act_layer(cfg.act_layer)
-            conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
+            conv_layer = partial(
+                conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
         attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
 
         stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index 9e5a62b2..a32ce019 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -118,7 +118,7 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     padding_same = kwargs.get('padding_same', True)
     if padding_same:
         stem_type = 'same'
-        conv_layer = StdConv2dSame
+        conv_layer = partial(StdConv2dSame, eps=1e-5)
     else:
         stem_type = ''
         conv_layer = StdConv2d
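
A standalone sketch (not part of the patch) of what the two get_weight()
paths above compute, with illustrative shapes and eps. Note F.layer_norm
applies eps inside the square root, (w - mean) / sqrt(var + eps), while the
manual path adds it to the std outside it, (w - mean) / (std + eps), so the
two differ slightly for eps > 0 -- presumably the "eps value tweaking" the
subject line flags as a WIP:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
weight = torch.randn(64, 32, 3, 3)  # (out_channels, in_channels, kH, kW)

# use_layernorm=True path: one fused per-filter op over dims [1, 2, 3],
# computing (w - mean) / sqrt(var + eps); this form works under PyTorch XLA.
w_ln = F.layer_norm(weight, weight.shape[1:], eps=eps)

# use_layernorm=False path: same mean/var reduction, but eps is added to the
# std outside the sqrt.
std, mean = torch.std_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_manual = (weight - mean) / (std + eps)

print((w_ln - w_manual).abs().max())  # small but nonzero; shrinks as eps -> 0

# ScaledStdConv2d then folds the learned per-filter gain and the fan-in scale
# into the normalized weight afterwards, as in the patched return statement.
gain = torch.ones(64, 1, 1, 1)           # nn.Parameter in the real module
scale = 1.0 * weight[0].numel() ** -0.5  # gamma / sqrt(fan-in), gamma=1.0 here
w_scaled = w_ln * (gain * scale)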