diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 0f8b0464..97971ba6 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -17,8 +17,8 @@ from torch import nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import trunc_normal_tf_ -from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from .fx_features import register_notrace_module +from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .registry import register_model @@ -53,6 +53,7 @@ default_cfgs = dict( ) +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): def __init__(self, hidden_dim=32, dim=768, temperature=10000): super().__init__() @@ -349,6 +350,7 @@ class EdgeNeXt(nn.Module): self.drop_rate = drop_rate norm_layer = partial(LayerNorm2d, eps=1e-6) norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + self.feature_info = [] assert stem_type in ('patch', 'overlap') if stem_type == 'patch': @@ -362,14 +364,18 @@ class EdgeNeXt(nn.Module): norm_layer(dims[0]), ) + curr_stride = 4 stages = [] dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] in_chs = dims[0] for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + # FIXME support dilation / output_stride + curr_stride *= stride stages.append(EdgeNeXtStage( in_chs=in_chs, out_chs=dims[i], - stride=2 if i > 0 else 1, + stride=stride, depth=depths[i], num_global_blocks=global_block_counts[i], num_heads=heads[i], @@ -385,7 +391,10 @@ class EdgeNeXt(nn.Module): norm_layer_cl=norm_layer_cl, act_layer=act_layer, )) + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 in_chs = dims[i] + self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) self.num_features = dims[-1]