@@ -17,8 +17,8 @@ from torch import nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.models.layers import trunc_normal_tf_
-from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
+from .fx_features import register_notrace_module
+from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .registry import register_model
 
@@ -53,6 +53,7 @@ default_cfgs = dict(
 )
 
 
+@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
 class PositionalEncodingFourier(nn.Module):
     def __init__(self, hidden_dim=32, dim=768, temperature=10000):
         super().__init__()
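For context on the new register_notrace_module import: FX symbolic tracing swaps real tensors for fx.Proxy objects, and torch.arange in PositionalEncodingFourier.forward then receives a proxied length it cannot accept. A minimal standalone sketch of the failure (illustrative, not code from this PR):

import torch
from torch import nn

class ArangePos(nn.Module):
    def forward(self, x):
        # under symbolic tracing, x.shape[-1] is an fx.Proxy;
        # torch.arange rejects it, so tracing fails on this line
        pos = torch.arange(x.shape[-1], device=x.device, dtype=x.dtype)
        return x + pos

torch.fx.symbolic_trace(ArangePos())  # raises a tracing error here

Registering the module as notrace tells timm's FX tracer to treat it as a leaf call instead of tracing through its forward.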
@@ -349,6 +350,7 @@ class EdgeNeXt(nn.Module):
         self.drop_rate = drop_rate
         norm_layer = partial(LayerNorm2d, eps=1e-6)
         norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
+        self.feature_info = []
 
         assert stem_type in ('patch', 'overlap')
         if stem_type == 'patch':
@@ -362,14 +364,18 @@ class EdgeNeXt(nn.Module):
                 norm_layer(dims[0]),
             )
 
+        curr_stride = 4
         stages = []
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         in_chs = dims[0]
         for i in range(4):
+            stride = 2 if curr_stride == 2 or i > 0 else 1
+            # FIXME support dilation / output_stride
+            curr_stride *= stride
             stages.append(EdgeNeXtStage(
                 in_chs=in_chs,
                 out_chs=dims[i],
-                stride=2 if i > 0 else 1,
+                stride=stride,
                 depth=depths[i],
                 num_global_blocks=global_block_counts[i],
                 num_heads=heads[i],
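A quick standalone trace of the curr_stride bookkeeping added above, assuming the default four-stage config with the stride-4 patch stem (so curr_stride starts at 4 and the curr_stride == 2 branch never triggers here):

curr_stride = 4
reductions = []
for i in range(4):
    stride = 2 if curr_stride == 2 or i > 0 else 1
    curr_stride *= stride
    reductions.append(curr_stride)
print(reductions)  # [4, 8, 16, 32]: stage 0 keeps the stem stride, later stages halve resolution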
@@ -385,7 +391,10 @@ class EdgeNeXt(nn.Module):
                 norm_layer_cl=norm_layer_cl,
                 act_layer=act_layer,
             ))
+            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
             in_chs = dims[i]
+            self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
+
         self.stages = nn.Sequential(*stages)
 
         self.num_features = dims[-1]
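With feature_info populated per stage, the model should plug into timm's features_only interface. A hedged usage sketch (edgenext_small is assumed to be registered; channel counts and shapes depend on the actual config):

import torch
import timm

model = timm.create_model('edgenext_small', features_only=True, pretrained=False)
print(model.feature_info.channels())   # per-stage channel counts
print(model.feature_info.reduction())  # per-stage strides, e.g. [4, 8, 16, 32]

feats = model(torch.randn(1, 3, 256, 256))
for f in feats:
    print(f.shape)  # one NCHW feature map per stage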