Improve RegVGG block identity/vs non for clariy and fix attn usage. Add comments.

pull/419/head
Ross Wightman 4 years ago
parent 0356e773f5
commit 6853b07bbd

@ -32,7 +32,6 @@ from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
@ -443,7 +442,7 @@ class RepVggBlock(nn.Module):
Adapted from impl at https://github.com/DingXiaoH/RepVGG Adapted from impl at https://github.com/DingXiaoH/RepVGG
This version does not currently support the deploy optimization. It is currently fixed in 'train' model. This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
""" """
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
@ -461,8 +460,8 @@ class RepVggBlock(nn.Module):
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block, apply_act=False, **layer_args) groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args) self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args)
self.attn = None if attn_layer is None else attn_layer(out_chs) self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.act = act_layer(inplace=True) self.act = act_layer(inplace=True)
def init_weights(self, zero_init_last_bn=False): def init_weights(self, zero_init_last_bn=False):
@ -474,14 +473,14 @@ class RepVggBlock(nn.Module):
def forward(self, x): def forward(self, x):
if self.identity is None: if self.identity is None:
identity = 0 x = self.conv_1x1(x) + self.conv_kxk(x)
else: else:
identity = self.identity(x) identity = self.identity(x)
x = self.conv_1x1(x) + self.conv_kxk(x) x = self.conv_1x1(x) + self.conv_kxk(x)
if self.attn is not None: x = self.drop_path(x) # not in the paper / official impl, experimental
x = self.attn(x) x = x + identity
x = self.drop_path(x) x = self.attn(x) # no attn in the paper / official impl, experimental
x = self.act(x + identity) x = self.act(x)
return x return x
@ -654,54 +653,87 @@ def _create_byobnet(variant, pretrained=False, **kwargs):
@register_model @register_model
def gernet_l(pretrained=False, **kwargs): def gernet_l(pretrained=False, **kwargs):
""" GEResNet-Large (GENet-Large from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs) return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
@register_model @register_model
def gernet_m(pretrained=False, **kwargs): def gernet_m(pretrained=False, **kwargs):
""" GEResNet-Medium (GENet-Normal from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs) return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
@register_model @register_model
def gernet_s(pretrained=False, **kwargs): def gernet_s(pretrained=False, **kwargs):
""" EResNet-Small (GENet-Small from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs) return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_a2(pretrained=False, **kwargs): def repvgg_a2(pretrained=False, **kwargs):
""" RepVGG-A2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_b0(pretrained=False, **kwargs): def repvgg_b0(pretrained=False, **kwargs):
""" RepVGG-B0
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_b1(pretrained=False, **kwargs): def repvgg_b1(pretrained=False, **kwargs):
""" RepVGG-B1
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_b1g4(pretrained=False, **kwargs): def repvgg_b1g4(pretrained=False, **kwargs):
""" RepVGG-B1g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_b2(pretrained=False, **kwargs): def repvgg_b2(pretrained=False, **kwargs):
""" RepVGG-B2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_b2g4(pretrained=False, **kwargs): def repvgg_b2g4(pretrained=False, **kwargs):
""" RepVGG-B2g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_b3(pretrained=False, **kwargs): def repvgg_b3(pretrained=False, **kwargs):
""" RepVGG-B3
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
@register_model @register_model
def repvgg_b3g4(pretrained=False, **kwargs): def repvgg_b3g4(pretrained=False, **kwargs):
""" RepVGG-B3g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs) return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save