Improve RegVGG block identity/vs non for clariy and fix attn usage. Add comments.

pull/419/head
Ross Wightman 3 years ago
parent 0356e773f5
commit 6853b07bbd

@ -32,7 +32,6 @@ from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
@ -443,7 +442,7 @@ class RepVggBlock(nn.Module):
Adapted from impl at https://github.com/DingXiaoH/RepVGG
This version does not currently support the deploy optimization. It is currently fixed in 'train' model.
This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
@ -461,8 +460,8 @@ class RepVggBlock(nn.Module):
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block, apply_act=False, **layer_args)
self.conv_1x1 = ConvBnAct(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False, **layer_args)
self.attn = None if attn_layer is None else attn_layer(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.attn = nn.Identity() if attn_layer is None else attn_layer(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.act = act_layer(inplace=True)
def init_weights(self, zero_init_last_bn=False):
@ -474,14 +473,14 @@ class RepVggBlock(nn.Module):
def forward(self, x):
if self.identity is None:
identity = 0
x = self.conv_1x1(x) + self.conv_kxk(x)
else:
identity = self.identity(x)
x = self.conv_1x1(x) + self.conv_kxk(x)
if self.attn is not None:
x = self.attn(x)
x = self.drop_path(x)
x = self.act(x + identity)
x = self.conv_1x1(x) + self.conv_kxk(x)
x = self.drop_path(x) # not in the paper / official impl, experimental
x = x + identity
x = self.attn(x) # no attn in the paper / official impl, experimental
x = self.act(x)
return x
@ -654,54 +653,87 @@ def _create_byobnet(variant, pretrained=False, **kwargs):
@register_model
def gernet_l(pretrained=False, **kwargs):
""" GEResNet-Large (GENet-Large from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
@register_model
def gernet_m(pretrained=False, **kwargs):
""" GEResNet-Medium (GENet-Normal from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
@register_model
def gernet_s(pretrained=False, **kwargs):
""" EResNet-Small (GENet-Small from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
@register_model
def repvgg_a2(pretrained=False, **kwargs):
""" RepVGG-A2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b0(pretrained=False, **kwargs):
""" RepVGG-B0
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1(pretrained=False, **kwargs):
""" RepVGG-B1
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1g4(pretrained=False, **kwargs):
""" RepVGG-B1g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2(pretrained=False, **kwargs):
""" RepVGG-B2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2g4(pretrained=False, **kwargs):
""" RepVGG-B2g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3(pretrained=False, **kwargs):
""" RepVGG-B3
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3g4(pretrained=False, **kwargs):
""" RepVGG-B3g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save