|
|
|
@ -13,15 +13,18 @@ Weights from original impl have been modified
|
|
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import math
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Optional, Union, Callable
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .helpers import build_model_with_cfg, named_apply
|
|
|
|
|
from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, get_act_layer, GroupNormAct
|
|
|
|
|
from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
|
|
|
|
|
from .layers import get_act_layer, get_norm_act_layer, create_conv2d
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -37,6 +40,8 @@ class RegNetCfg:
|
|
|
|
|
stem_width: int = 32
|
|
|
|
|
downsample: Optional[str] = 'conv1x1'
|
|
|
|
|
linear_out: bool = False
|
|
|
|
|
preact: bool = False
|
|
|
|
|
num_features: int = 0
|
|
|
|
|
act_layer: Union[str, Callable] = 'relu'
|
|
|
|
|
norm_layer: Union[str, Callable] = 'batchnorm'
|
|
|
|
|
|
|
|
|
@ -75,15 +80,23 @@ model_cfgs = dict(
|
|
|
|
|
regnety_040s_gn=RegNetCfg(
|
|
|
|
|
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
|
|
|
|
|
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
|
|
|
|
|
# regnetv = 'preact regnet y'
|
|
|
|
|
regnetv_040=RegNetCfg(
|
|
|
|
|
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
|
|
|
|
|
# regnetw = 'preact regnet z'
|
|
|
|
|
regnetw_040=RegNetCfg(
|
|
|
|
|
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
|
|
|
|
|
downsample=None, preact=True, num_features=1536, act_layer='silu',
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
# RegNet-Z (unverified)
|
|
|
|
|
regnetz_005=RegNetCfg(
|
|
|
|
|
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
|
|
|
|
|
downsample=None, linear_out=True, act_layer='silu',
|
|
|
|
|
downsample=None, linear_out=True, num_features=1024, act_layer='silu',
|
|
|
|
|
),
|
|
|
|
|
regnetz_040=RegNetCfg(
|
|
|
|
|
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
|
|
|
|
|
downsample=None, linear_out=True, act_layer='silu',
|
|
|
|
|
downsample=None, linear_out=True, num_features=1536, act_layer='silu',
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -130,6 +143,8 @@ default_cfgs = dict(
|
|
|
|
|
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
|
|
|
|
|
|
|
|
|
|
regnety_040s_gn=_cfg(url=''),
|
|
|
|
|
regnetv_040=_cfg(url=''),
|
|
|
|
|
regnetw_040=_cfg(url=''),
|
|
|
|
|
|
|
|
|
|
regnetz_005=_cfg(url=''),
|
|
|
|
|
regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
|
|
|
@ -162,15 +177,18 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
|
|
|
|
|
return widths, num_stages, max_stage, widths_cont
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
|
|
|
|
|
def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False):
|
|
|
|
|
norm_layer = norm_layer or nn.BatchNorm2d
|
|
|
|
|
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
|
|
|
|
|
dilation = dilation if kernel_size > 1 else 1
|
|
|
|
|
return ConvNormAct(
|
|
|
|
|
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False)
|
|
|
|
|
if preact:
|
|
|
|
|
return create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation)
|
|
|
|
|
else:
|
|
|
|
|
return ConvNormAct(
|
|
|
|
|
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
|
|
|
|
|
def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False):
|
|
|
|
|
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
|
|
|
|
|
norm_layer = norm_layer or nn.BatchNorm2d
|
|
|
|
|
avg_stride = stride if dilation == 1 else 1
|
|
|
|
@ -178,20 +196,24 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la
|
|
|
|
|
if stride > 1 or dilation > 1:
|
|
|
|
|
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
|
|
|
|
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
|
|
|
|
return nn.Sequential(*[
|
|
|
|
|
pool, ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)])
|
|
|
|
|
if preact:
|
|
|
|
|
conv = create_conv2d(in_chs, out_chs, 1, stride=1)
|
|
|
|
|
else:
|
|
|
|
|
conv = ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)
|
|
|
|
|
return nn.Sequential(*[pool, conv])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_shortcut(downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None):
|
|
|
|
|
def create_shortcut(
|
|
|
|
|
downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
|
|
|
|
|
assert downsample_type in ('avg', 'conv1x1', '', None)
|
|
|
|
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
|
|
|
|
dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
|
|
|
|
|
if not downsample_type:
|
|
|
|
|
return None # no shortcut, no downsample
|
|
|
|
|
elif downsample_type == 'avg':
|
|
|
|
|
return downsample_avg(in_chs, out_chs, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
|
|
|
|
|
return downsample_avg(in_chs, out_chs, **dargs)
|
|
|
|
|
else:
|
|
|
|
|
return downsample_conv(
|
|
|
|
|
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
|
|
|
|
|
return downsample_conv(in_chs, out_chs, kernel_size=kernel_size, **dargs)
|
|
|
|
|
else:
|
|
|
|
|
return nn.Identity() # identity shortcut (no downsample)
|
|
|
|
|
|
|
|
|
@ -203,9 +225,10 @@ class Bottleneck(nn.Module):
|
|
|
|
|
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
|
|
|
|
|
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
drop_block=None, drop_path_rate=0.):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
|
|
|
|
|
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
drop_block=None, drop_path_rate=0.):
|
|
|
|
|
super(Bottleneck, self).__init__()
|
|
|
|
|
act_layer = get_act_layer(act_layer)
|
|
|
|
|
bottleneck_chs = int(round(out_chs * bottle_ratio))
|
|
|
|
@ -238,22 +261,68 @@ class Bottleneck(nn.Module):
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
# NOTE stuck with downsample as the attr name due to weight compatibility
|
|
|
|
|
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
|
|
|
|
|
x = x + self.drop_path(self.downsample(shortcut))
|
|
|
|
|
x = self.drop_path(x) + self.downsample(shortcut)
|
|
|
|
|
x = self.act3(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PreBottleneck(nn.Module):
|
|
|
|
|
""" RegNet Bottleneck
|
|
|
|
|
|
|
|
|
|
This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from
|
|
|
|
|
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
|
|
|
|
|
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
drop_block=None, drop_path_rate=0.):
|
|
|
|
|
super(PreBottleneck, self).__init__()
|
|
|
|
|
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
|
|
|
|
|
bottleneck_chs = int(round(out_chs * bottle_ratio))
|
|
|
|
|
groups = bottleneck_chs // group_size
|
|
|
|
|
|
|
|
|
|
self.norm1 = norm_act_layer(in_chs)
|
|
|
|
|
self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1)
|
|
|
|
|
self.norm2 = norm_act_layer(bottleneck_chs)
|
|
|
|
|
self.conv2 = create_conv2d(
|
|
|
|
|
bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0], groups=groups)
|
|
|
|
|
if se_ratio:
|
|
|
|
|
se_channels = int(round(in_chs * se_ratio))
|
|
|
|
|
self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
|
|
|
|
|
else:
|
|
|
|
|
self.se = nn.Identity()
|
|
|
|
|
self.norm3 = norm_act_layer(bottleneck_chs)
|
|
|
|
|
self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1)
|
|
|
|
|
self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, preact=True)
|
|
|
|
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def zero_init_last(self):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.norm1(x)
|
|
|
|
|
shortcut = x
|
|
|
|
|
x = self.conv1(x)
|
|
|
|
|
x = self.norm2(x)
|
|
|
|
|
x = self.conv2(x)
|
|
|
|
|
x = self.se(x)
|
|
|
|
|
x = self.norm3(x)
|
|
|
|
|
x = self.conv3(x)
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
# NOTE stuck with downsample as the attr name due to weight compatibility
|
|
|
|
|
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
|
|
|
|
|
x = self.drop_path(x) + self.downsample(shortcut)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RegStage(nn.Module):
|
|
|
|
|
"""Stage (sequence of blocks w/ the same output shape)."""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self, depth, in_chs, out_chs, stride, dilation, bottle_ratio=1.0, group_size=8, block_fn=Bottleneck,
|
|
|
|
|
se_ratio=0., downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
|
|
|
|
drop_path_rates=None, drop_block=None):
|
|
|
|
|
self, depth, in_chs, out_chs, stride, dilation,
|
|
|
|
|
drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
|
|
|
|
|
super(RegStage, self).__init__()
|
|
|
|
|
block_kwargs = dict(
|
|
|
|
|
bottle_ratio=bottle_ratio, group_size=group_size, se_ratio=se_ratio, downsample=downsample,
|
|
|
|
|
linear_out=linear_out, act_layer=act_layer, norm_layer=norm_layer, drop_block=drop_block)
|
|
|
|
|
first_dilation = 1 if dilation in (1, 2) else 2
|
|
|
|
|
for i in range(depth):
|
|
|
|
|
block_stride = stride if i == 0 else 1
|
|
|
|
@ -291,30 +360,40 @@ class RegNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
# Construct the stem
|
|
|
|
|
stem_width = cfg.stem_width
|
|
|
|
|
self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
|
|
|
|
|
na_args = dict(act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
|
|
|
|
|
if cfg.preact:
|
|
|
|
|
self.stem = create_conv2d(in_chans, stem_width, 3, stride=2)
|
|
|
|
|
else:
|
|
|
|
|
self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, **na_args)
|
|
|
|
|
self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')]
|
|
|
|
|
|
|
|
|
|
# Construct the stages
|
|
|
|
|
prev_width = stem_width
|
|
|
|
|
curr_stride = 2
|
|
|
|
|
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
|
|
|
|
|
for i, stage_args in enumerate(stage_params):
|
|
|
|
|
per_stage_args, common_args = self._get_stage_args(
|
|
|
|
|
cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
|
|
|
|
|
block_fn = PreBottleneck if cfg.preact else Bottleneck
|
|
|
|
|
for i, stage_args in enumerate(per_stage_args):
|
|
|
|
|
stage_name = "s{}".format(i + 1)
|
|
|
|
|
self.add_module(stage_name, RegStage(
|
|
|
|
|
in_chs=prev_width, se_ratio=cfg.se_ratio, downsample=cfg.downsample, linear_out=cfg.linear_out,
|
|
|
|
|
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, **stage_args))
|
|
|
|
|
self.add_module(stage_name, RegStage(in_chs=prev_width, block_fn=block_fn, **stage_args, **common_args))
|
|
|
|
|
prev_width = stage_args['out_chs']
|
|
|
|
|
curr_stride *= stage_args['stride']
|
|
|
|
|
self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
|
|
|
|
|
|
|
|
|
|
# Construct the head
|
|
|
|
|
self.num_features = prev_width
|
|
|
|
|
if cfg.num_features:
|
|
|
|
|
self.final_conv = ConvNormAct(prev_width, cfg.num_features, kernel_size=1, **na_args)
|
|
|
|
|
self.num_features = cfg.num_features
|
|
|
|
|
else:
|
|
|
|
|
final_act = cfg.linear_out or cfg.preact
|
|
|
|
|
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
|
|
|
|
|
self.num_features = prev_width
|
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
|
in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
|
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
|
|
|
|
|
|
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
|
|
|
|
|
|
|
|
|
|
def _get_stage_params(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.):
|
|
|
|
|
def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.):
|
|
|
|
|
# Generate RegNet ws per block
|
|
|
|
|
widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth)
|
|
|
|
|
|
|
|
|
@ -341,12 +420,15 @@ class RegNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
# Adjust the compatibility of ws and gws
|
|
|
|
|
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
|
|
|
|
|
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
|
|
|
|
|
stage_params = [
|
|
|
|
|
dict(zip(param_names, params)) for params in
|
|
|
|
|
arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
|
|
|
|
|
per_stage_args = [
|
|
|
|
|
dict(zip(arg_names, params)) for params in
|
|
|
|
|
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
|
|
|
|
|
stage_dpr)]
|
|
|
|
|
return stage_params
|
|
|
|
|
common_args = dict(
|
|
|
|
|
downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
|
|
|
|
|
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
|
|
|
|
|
return per_stage_args, common_args
|
|
|
|
|
|
|
|
|
|
def get_classifier(self):
|
|
|
|
|
return self.head.fc
|
|
|
|
@ -367,14 +449,16 @@ class RegNet(nn.Module):
|
|
|
|
|
|
|
|
|
|
def _init_weights(module, name='', zero_init_last=False):
|
|
|
|
|
if isinstance(module, nn.Conv2d):
|
|
|
|
|
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
|
|
|
|
elif isinstance(module, nn.BatchNorm2d):
|
|
|
|
|
nn.init.ones_(module.weight)
|
|
|
|
|
nn.init.zeros_(module.bias)
|
|
|
|
|
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
|
|
|
|
|
fan_out //= module.groups
|
|
|
|
|
module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
|
|
|
if module.bias is not None:
|
|
|
|
|
module.bias.data.zero_()
|
|
|
|
|
elif isinstance(module, nn.Linear):
|
|
|
|
|
nn.init.normal_(module.weight, mean=0.0, std=0.01)
|
|
|
|
|
nn.init.zeros_(module.bias)
|
|
|
|
|
elif hasattr(module, 'zero_init_last'):
|
|
|
|
|
if module.bias is not None:
|
|
|
|
|
nn.init.zeros_(module.bias)
|
|
|
|
|
elif zero_init_last and hasattr(module, 'zero_init_last'):
|
|
|
|
|
module.zero_init_last()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -545,13 +629,25 @@ def regnety_040s_gn(pretrained=False, **kwargs):
|
|
|
|
|
return _create_regnet('regnety_040s_gn', pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def regnetv_040(pretrained=False, **kwargs):
|
|
|
|
|
""""""
|
|
|
|
|
return _create_regnet('regnetv_040', pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def regnetw_040(pretrained=False, **kwargs):
|
|
|
|
|
""""""
|
|
|
|
|
return _create_regnet('regnetw_040', pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def regnetz_005(pretrained=False, **kwargs):
|
|
|
|
|
"""RegNetZ-500MF
|
|
|
|
|
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
|
|
|
|
|
but it's not clear it is equivalent to paper model as not detailed in the paper.
|
|
|
|
|
"""
|
|
|
|
|
return _create_regnet('regnetz_005', pretrained, **kwargs)
|
|
|
|
|
return _create_regnet('regnetz_005', pretrained, zero_init_last=False, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@ -560,4 +656,4 @@ def regnetz_040(pretrained=False, **kwargs):
|
|
|
|
|
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
|
|
|
|
|
but it's not clear it is equivalent to paper model as not detailed in the paper.
|
|
|
|
|
"""
|
|
|
|
|
return _create_regnet('regnetz_040', pretrained, **kwargs)
|
|
|
|
|
return _create_regnet('regnetz_040', pretrained, zero_init_last=False, **kwargs)
|
|
|
|
|