Add feature_info to edgenext for features_only support, hopefully fix some fx / test errors

pull/1327/head
Ross Wightman 2 years ago
parent 377e9bfa21
commit dd9b8f57c4

@ -17,8 +17,8 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import trunc_normal_tf_ from .fx_features import register_notrace_module
from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .registry import register_model from .registry import register_model
@ -53,6 +53,7 @@ default_cfgs = dict(
) )
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module): class PositionalEncodingFourier(nn.Module):
def __init__(self, hidden_dim=32, dim=768, temperature=10000): def __init__(self, hidden_dim=32, dim=768, temperature=10000):
super().__init__() super().__init__()
@ -349,6 +350,7 @@ class EdgeNeXt(nn.Module):
self.drop_rate = drop_rate self.drop_rate = drop_rate
norm_layer = partial(LayerNorm2d, eps=1e-6) norm_layer = partial(LayerNorm2d, eps=1e-6)
norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
self.feature_info = []
assert stem_type in ('patch', 'overlap') assert stem_type in ('patch', 'overlap')
if stem_type == 'patch': if stem_type == 'patch':
@ -362,14 +364,18 @@ class EdgeNeXt(nn.Module):
norm_layer(dims[0]), norm_layer(dims[0]),
) )
curr_stride = 4
stages = [] stages = []
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
in_chs = dims[0] in_chs = dims[0]
for i in range(4): for i in range(4):
stride = 2 if curr_stride == 2 or i > 0 else 1
# FIXME support dilation / output_stride
curr_stride *= stride
stages.append(EdgeNeXtStage( stages.append(EdgeNeXtStage(
in_chs=in_chs, in_chs=in_chs,
out_chs=dims[i], out_chs=dims[i],
stride=2 if i > 0 else 1, stride=stride,
depth=depths[i], depth=depths[i],
num_global_blocks=global_block_counts[i], num_global_blocks=global_block_counts[i],
num_heads=heads[i], num_heads=heads[i],
@ -385,7 +391,10 @@ class EdgeNeXt(nn.Module):
norm_layer_cl=norm_layer_cl, norm_layer_cl=norm_layer_cl,
act_layer=act_layer, act_layer=act_layer,
)) ))
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
in_chs = dims[i] in_chs = dims[i]
self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages) self.stages = nn.Sequential(*stages)
self.num_features = dims[-1] self.num_features = dims[-1]

Loading…
Cancel
Save