Add feature_info to edgenext for features_only support, hopefully fix some fx / test errors

pull/1327/head
Ross Wightman 2 years ago
parent 377e9bfa21
commit dd9b8f57c4

@ -17,8 +17,8 @@ from torch import nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import trunc_normal_tf_
from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from .fx_features import register_notrace_module
from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .registry import register_model
@ -53,6 +53,7 @@ default_cfgs = dict(
)
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
def __init__(self, hidden_dim=32, dim=768, temperature=10000):
super().__init__()
@ -349,6 +350,7 @@ class EdgeNeXt(nn.Module):
self.drop_rate = drop_rate
norm_layer = partial(LayerNorm2d, eps=1e-6)
norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
self.feature_info = []
assert stem_type in ('patch', 'overlap')
if stem_type == 'patch':
@ -362,14 +364,18 @@ class EdgeNeXt(nn.Module):
norm_layer(dims[0]),
)
curr_stride = 4
stages = []
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
in_chs = dims[0]
for i in range(4):
stride = 2 if curr_stride == 2 or i > 0 else 1
# FIXME support dilation / output_stride
curr_stride *= stride
stages.append(EdgeNeXtStage(
in_chs=in_chs,
out_chs=dims[i],
stride=2 if i > 0 else 1,
stride=stride,
depth=depths[i],
num_global_blocks=global_block_counts[i],
num_heads=heads[i],
@ -385,7 +391,10 @@ class EdgeNeXt(nn.Module):
norm_layer_cl=norm_layer_cl,
act_layer=act_layer,
))
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
in_chs = dims[i]
self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = dims[-1]

Loading…
Cancel
Save