Merge branch 'norm_norm_norm' into bits_and_tpu

pull/1239/head
Ross Wightman 2 years ago
commit cbc4f33220

@ -120,7 +120,7 @@ default_cfgs = {
# FIXME experimental # FIXME experimental
'efficientnet_b0_gn': _cfg( 'efficientnet_b0_gn': _cfg(
url=''), url=''),
'efficientnet_b0_g8': _cfg( 'efficientnet_b0_g8_gn': _cfg(
url=''), url=''),
'efficientnet_b0_g16_evos': _cfg( 'efficientnet_b0_g16_evos': _cfg(
url=''), url=''),
@ -1389,10 +1389,11 @@ def efficientnet_b0_gn(pretrained=False, **kwargs):
@register_model @register_model
def efficientnet_b0_g8(pretrained=False, **kwargs): def efficientnet_b0_g8_gn(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ group conv + BN""" """ EfficientNet-B0 w/ group conv + GroupNorm"""
model = _gen_efficientnet( model = _gen_efficientnet(
'efficientnet_b0_g8', group_size=8, pretrained=pretrained, **kwargs) 'efficientnet_b0_g8_gn', group_size=8, norm_layer=partial(GroupNormAct, group_size=8),
pretrained=pretrained, **kwargs)
return model return model

@ -19,11 +19,7 @@ def num_groups(group_size, channels):
return 1 # normal conv with 1 group return 1 # normal conv with 1 group
else: else:
# NOTE group_size == 1 -> depthwise conv # NOTE group_size == 1 -> depthwise conv
#assert channels % group_size == 0 assert channels % group_size == 0
if channels % group_size != 0:
num_groups = math.floor(channels / group_size)
print(channels, group_size, num_groups)
return int(num_groups)
return channels // group_size return channels // group_size
@ -87,7 +83,7 @@ class ConvBnAct(nn.Module):
x = self.conv(x) x = self.conv(x)
x = self.bn1(x) x = self.bn1(x)
if self.has_skip: if self.has_skip:
x = x + self.drop_path(shortcut) x = self.drop_path(x) + shortcut
return x return x
@ -131,7 +127,7 @@ class DepthwiseSeparableConv(nn.Module):
x = self.conv_pw(x) x = self.conv_pw(x)
x = self.bn2(x) x = self.bn2(x)
if self.has_skip: if self.has_skip:
x = x + self.drop_path(shortcut) x = self.drop_path(x) + shortcut
return x return x
@ -190,7 +186,7 @@ class InvertedResidual(nn.Module):
x = self.conv_pwl(x) x = self.conv_pwl(x)
x = self.bn3(x) x = self.bn3(x)
if self.has_skip: if self.has_skip:
x = x + self.drop_path(shortcut) x = self.drop_path(x) + shortcut
return x return x
@ -225,7 +221,7 @@ class CondConvResidual(InvertedResidual):
x = self.conv_pwl(x, routing_weights) x = self.conv_pwl(x, routing_weights)
x = self.bn3(x) x = self.bn3(x)
if self.has_skip: if self.has_skip:
x = x + self.drop_path(shortcut) x = self.drop_path(x) + shortcut
return x return x
@ -281,5 +277,5 @@ class EdgeResidual(nn.Module):
x = self.conv_pwl(x) x = self.conv_pwl(x)
x = self.bn2(x) x = self.bn2(x)
if self.has_skip: if self.has_skip:
x = x + self.drop_path(shortcut) x = self.drop_path(x) + shortcut
return x return x

@ -40,7 +40,7 @@ def get_bn_args_tf():
def resolve_bn_args(kwargs): def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} bn_args = {}
bn_momentum = kwargs.pop('bn_momentum', None) bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None: if bn_momentum is not None:
bn_args['momentum'] = bn_momentum bn_args['momentum'] = bn_momentum

@ -47,13 +47,6 @@ def create_model(
""" """
source_name, model_name = split_model_name(model_name) source_name, model_name = split_model_name(model_name)
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
if not is_efficientnet:
kwargs.pop('bn_tf', None)
kwargs.pop('bn_momentum', None)
kwargs.pop('bn_eps', None)
# handle backwards compat with drop_connect -> drop_path change # handle backwards compat with drop_connect -> drop_path change
drop_connect_rate = kwargs.pop('drop_connect_rate', None) drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:

@ -13,15 +13,18 @@ Weights from original impl have been modified
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import numpy as np import math
import torch.nn as nn
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Optional, Union, Callable from typing import Optional, Union, Callable
import numpy as np
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, get_act_layer, GroupNormAct from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
from .layers import get_act_layer, get_norm_act_layer, create_conv2d
from .registry import register_model from .registry import register_model
@ -37,6 +40,8 @@ class RegNetCfg:
stem_width: int = 32 stem_width: int = 32
downsample: Optional[str] = 'conv1x1' downsample: Optional[str] = 'conv1x1'
linear_out: bool = False linear_out: bool = False
preact: bool = False
num_features: int = 0
act_layer: Union[str, Callable] = 'relu' act_layer: Union[str, Callable] = 'relu'
norm_layer: Union[str, Callable] = 'batchnorm' norm_layer: Union[str, Callable] = 'batchnorm'
@ -75,15 +80,23 @@ model_cfgs = dict(
regnety_040s_gn=RegNetCfg( regnety_040s_gn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25, w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)), act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# regnetv = 'preact regnet y'
regnetv_040=RegNetCfg(
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
# regnetw = 'preact regnet z'
regnetw_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, preact=True, num_features=1536, act_layer='silu',
),
# RegNet-Z (unverified) # RegNet-Z (unverified)
regnetz_005=RegNetCfg( regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25, depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu', downsample=None, linear_out=True, num_features=1024, act_layer='silu',
), ),
regnetz_040=RegNetCfg( regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu', downsample=None, linear_out=True, num_features=1536, act_layer='silu',
), ),
) )
@ -130,6 +143,8 @@ default_cfgs = dict(
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
regnety_040s_gn=_cfg(url=''), regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(url=''),
regnetw_040=_cfg(url=''),
regnetz_005=_cfg(url=''), regnetz_005=_cfg(url=''),
regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@ -162,15 +177,18 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
return widths, num_stages, max_stage, widths_cont return widths, num_stages, max_stage, widths_cont
def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None): def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False):
norm_layer = norm_layer or nn.BatchNorm2d norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1 dilation = dilation if kernel_size > 1 else 1
return ConvNormAct( if preact:
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False) return create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation)
else:
return ConvNormAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False)
def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None): def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1 avg_stride = stride if dilation == 1 else 1
@ -178,20 +196,24 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la
if stride > 1 or dilation > 1: if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[ if preact:
pool, ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)]) conv = create_conv2d(in_chs, out_chs, 1, stride=1)
else:
conv = ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)
return nn.Sequential(*[pool, conv])
def create_shortcut(downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None): def create_shortcut(
downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
assert downsample_type in ('avg', 'conv1x1', '', None) assert downsample_type in ('avg', 'conv1x1', '', None)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
if not downsample_type: if not downsample_type:
return None # no shortcut, no downsample return None # no shortcut, no downsample
elif downsample_type == 'avg': elif downsample_type == 'avg':
return downsample_avg(in_chs, out_chs, stride=stride, dilation=dilation[0], norm_layer=norm_layer) return downsample_avg(in_chs, out_chs, **dargs)
else: else:
return downsample_conv( return downsample_conv(in_chs, out_chs, kernel_size=kernel_size, **dargs)
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
else: else:
return nn.Identity() # identity shortcut (no downsample) return nn.Identity() # identity shortcut (no downsample)
@ -203,9 +225,10 @@ class Bottleneck(nn.Module):
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
""" """
def __init__(self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, def __init__(
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
drop_block=None, drop_path_rate=0.): downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
act_layer = get_act_layer(act_layer) act_layer = get_act_layer(act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio)) bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -238,22 +261,68 @@ class Bottleneck(nn.Module):
if self.downsample is not None: if self.downsample is not None:
# NOTE stuck with downsample as the attr name due to weight compatibility # NOTE stuck with downsample as the attr name due to weight compatibility
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity() # now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
x = x + self.drop_path(self.downsample(shortcut)) x = self.drop_path(x) + self.downsample(shortcut)
x = self.act3(x) x = self.act3(x)
return x return x
class PreBottleneck(nn.Module):
""" RegNet Bottleneck
This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
super(PreBottleneck, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
groups = bottleneck_chs // group_size
self.norm1 = norm_act_layer(in_chs)
self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1)
self.norm2 = norm_act_layer(bottleneck_chs)
self.conv2 = create_conv2d(
bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0], groups=groups)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
else:
self.se = nn.Identity()
self.norm3 = norm_act_layer(bottleneck_chs)
self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1)
self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, preact=True)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last(self):
pass
def forward(self, x):
x = self.norm1(x)
shortcut = x
x = self.conv1(x)
x = self.norm2(x)
x = self.conv2(x)
x = self.se(x)
x = self.norm3(x)
x = self.conv3(x)
if self.downsample is not None:
# NOTE stuck with downsample as the attr name due to weight compatibility
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
x = self.drop_path(x) + self.downsample(shortcut)
return x
class RegStage(nn.Module): class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape).""" """Stage (sequence of blocks w/ the same output shape)."""
def __init__( def __init__(
self, depth, in_chs, out_chs, stride, dilation, bottle_ratio=1.0, group_size=8, block_fn=Bottleneck, self, depth, in_chs, out_chs, stride, dilation,
se_ratio=0., downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
drop_path_rates=None, drop_block=None):
super(RegStage, self).__init__() super(RegStage, self).__init__()
block_kwargs = dict(
bottle_ratio=bottle_ratio, group_size=group_size, se_ratio=se_ratio, downsample=downsample,
linear_out=linear_out, act_layer=act_layer, norm_layer=norm_layer, drop_block=drop_block)
first_dilation = 1 if dilation in (1, 2) else 2 first_dilation = 1 if dilation in (1, 2) else 2
for i in range(depth): for i in range(depth):
block_stride = stride if i == 0 else 1 block_stride = stride if i == 0 else 1
@ -291,30 +360,40 @@ class RegNet(nn.Module):
# Construct the stem # Construct the stem
stem_width = cfg.stem_width stem_width = cfg.stem_width
self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) na_args = dict(act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
if cfg.preact:
self.stem = create_conv2d(in_chans, stem_width, 3, stride=2)
else:
self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, **na_args)
self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')]
# Construct the stages # Construct the stages
prev_width = stem_width prev_width = stem_width
curr_stride = 2 curr_stride = 2
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) per_stage_args, common_args = self._get_stage_args(
for i, stage_args in enumerate(stage_params): cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
block_fn = PreBottleneck if cfg.preact else Bottleneck
for i, stage_args in enumerate(per_stage_args):
stage_name = "s{}".format(i + 1) stage_name = "s{}".format(i + 1)
self.add_module(stage_name, RegStage( self.add_module(stage_name, RegStage(in_chs=prev_width, block_fn=block_fn, **stage_args, **common_args))
in_chs=prev_width, se_ratio=cfg.se_ratio, downsample=cfg.downsample, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, **stage_args))
prev_width = stage_args['out_chs'] prev_width = stage_args['out_chs']
curr_stride *= stage_args['stride'] curr_stride *= stage_args['stride']
self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
# Construct the head # Construct the head
self.num_features = prev_width if cfg.num_features:
self.final_conv = ConvNormAct(prev_width, cfg.num_features, kernel_size=1, **na_args)
self.num_features = cfg.num_features
else:
final_act = cfg.linear_out or cfg.preact
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
self.num_features = prev_width
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
def _get_stage_params(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.): def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.):
# Generate RegNet ws per block # Generate RegNet ws per block
widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth) widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth)
@ -341,12 +420,15 @@ class RegNet(nn.Module):
# Adjust the compatibility of ws and gws # Adjust the compatibility of ws and gws
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates'] arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
stage_params = [ per_stage_args = [
dict(zip(param_names, params)) for params in dict(zip(arg_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups, zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
stage_dpr)] stage_dpr)]
return stage_params common_args = dict(
downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
return per_stage_args, common_args
def get_classifier(self): def get_classifier(self):
return self.head.fc return self.head.fc
@ -367,14 +449,16 @@ class RegNet(nn.Module):
def _init_weights(module, name='', zero_init_last=False): def _init_weights(module, name='', zero_init_last=False):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
elif isinstance(module, nn.BatchNorm2d): fan_out //= module.groups
nn.init.ones_(module.weight) module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
nn.init.zeros_(module.bias) if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Linear): elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01) nn.init.normal_(module.weight, mean=0.0, std=0.01)
nn.init.zeros_(module.bias) if module.bias is not None:
elif hasattr(module, 'zero_init_last'): nn.init.zeros_(module.bias)
elif zero_init_last and hasattr(module, 'zero_init_last'):
module.zero_init_last() module.zero_init_last()
@ -545,13 +629,25 @@ def regnety_040s_gn(pretrained=False, **kwargs):
return _create_regnet('regnety_040s_gn', pretrained, **kwargs) return _create_regnet('regnety_040s_gn', pretrained, **kwargs)
@register_model
def regnetv_040(pretrained=False, **kwargs):
""""""
return _create_regnet('regnetv_040', pretrained, **kwargs)
@register_model
def regnetw_040(pretrained=False, **kwargs):
""""""
return _create_regnet('regnetw_040', pretrained, **kwargs)
@register_model @register_model
def regnetz_005(pretrained=False, **kwargs): def regnetz_005(pretrained=False, **kwargs):
"""RegNetZ-500MF """RegNetZ-500MF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear it is equivalent to paper model as not detailed in the paper. but it's not clear it is equivalent to paper model as not detailed in the paper.
""" """
return _create_regnet('regnetz_005', pretrained, **kwargs) return _create_regnet('regnetz_005', pretrained, zero_init_last=False, **kwargs)
@register_model @register_model
@ -560,4 +656,4 @@ def regnetz_040(pretrained=False, **kwargs):
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear it is equivalent to paper model as not detailed in the paper. but it's not clear it is equivalent to paper model as not detailed in the paper.
""" """
return _create_regnet('regnetz_040', pretrained, **kwargs) return _create_regnet('regnetz_040', pretrained, zero_init_last=False, **kwargs)

Loading…
Cancel
Save