|
|
@ -71,17 +71,18 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals
|
|
|
|
module.init_weights()
|
|
|
|
module.init_weights()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
|
|
|
|
def get_stage(
|
|
|
|
|
|
|
|
index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
|
|
|
|
norm_layer, act_layer, num_layers, bidirectional, union,
|
|
|
|
norm_layer, act_layer, num_layers, bidirectional, union,
|
|
|
|
with_fc, drop=0., drop_path_rate=0., **kwargs):
|
|
|
|
with_fc, drop=0., drop_path_rate=0., **kwargs):
|
|
|
|
assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
|
|
|
|
assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
|
|
|
|
blocks = []
|
|
|
|
blocks = []
|
|
|
|
for block_idx in range(layers[index]):
|
|
|
|
for block_idx in range(layers[index]):
|
|
|
|
drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
|
|
|
|
drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
|
|
|
|
blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
|
|
|
|
blocks.append(block_layer(
|
|
|
|
rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer,
|
|
|
|
embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
|
|
|
|
act_layer=act_layer, num_layers=num_layers,
|
|
|
|
rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
|
|
|
|
bidirectional=bidirectional, union=union, with_fc=with_fc,
|
|
|
|
num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc,
|
|
|
|
drop=drop, drop_path=drop_path))
|
|
|
|
drop=drop, drop_path=drop_path))
|
|
|
|
|
|
|
|
|
|
|
|
if index < len(embed_dims) - 1:
|
|
|
|
if index < len(embed_dims) - 1:
|
|
|
@ -101,7 +102,8 @@ class RNNIdentity(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
class RNN2DBase(nn.Module):
|
|
|
|
class RNN2DBase(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, input_size: int, hidden_size: int,
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self, input_size: int, hidden_size: int,
|
|
|
|
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
|
|
|
|
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
|
|
|
|
union="cat", with_fc=True):
|
|
|
|
union="cat", with_fc=True):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
@ -115,6 +117,7 @@ class RNN2DBase(nn.Module):
|
|
|
|
self.with_horizontal = True
|
|
|
|
self.with_horizontal = True
|
|
|
|
self.with_fc = with_fc
|
|
|
|
self.with_fc = with_fc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.fc = None
|
|
|
|
if with_fc:
|
|
|
|
if with_fc:
|
|
|
|
if union == "cat":
|
|
|
|
if union == "cat":
|
|
|
|
self.fc = nn.Linear(2 * self.output_size, input_size)
|
|
|
|
self.fc = nn.Linear(2 * self.output_size, input_size)
|
|
|
@ -159,23 +162,27 @@ class RNN2DBase(nn.Module):
|
|
|
|
v, _ = self.rnn_v(v)
|
|
|
|
v, _ = self.rnn_v(v)
|
|
|
|
v = v.reshape(B, W, H, -1)
|
|
|
|
v = v.reshape(B, W, H, -1)
|
|
|
|
v = v.permute(0, 2, 1, 3)
|
|
|
|
v = v.permute(0, 2, 1, 3)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
v = None
|
|
|
|
|
|
|
|
|
|
|
|
if self.with_horizontal:
|
|
|
|
if self.with_horizontal:
|
|
|
|
h = x.reshape(-1, W, C)
|
|
|
|
h = x.reshape(-1, W, C)
|
|
|
|
h, _ = self.rnn_h(h)
|
|
|
|
h, _ = self.rnn_h(h)
|
|
|
|
h = h.reshape(B, H, W, -1)
|
|
|
|
h = h.reshape(B, H, W, -1)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
h = None
|
|
|
|
|
|
|
|
|
|
|
|
if self.with_vertical and self.with_horizontal:
|
|
|
|
if v is not None and h is not None:
|
|
|
|
if self.union == "cat":
|
|
|
|
if self.union == "cat":
|
|
|
|
x = torch.cat([v, h], dim=-1)
|
|
|
|
x = torch.cat([v, h], dim=-1)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
x = v + h
|
|
|
|
x = v + h
|
|
|
|
elif self.with_vertical:
|
|
|
|
elif v is not None:
|
|
|
|
x = v
|
|
|
|
x = v
|
|
|
|
elif self.with_horizontal:
|
|
|
|
elif h is not None:
|
|
|
|
x = h
|
|
|
|
x = h
|
|
|
|
|
|
|
|
|
|
|
|
if self.with_fc:
|
|
|
|
if self.fc is not None:
|
|
|
|
x = self.fc(x)
|
|
|
|
x = self.fc(x)
|
|
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
return x
|
|
|
@ -183,7 +190,8 @@ class RNN2DBase(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
class LSTM2D(RNN2DBase):
|
|
|
|
class LSTM2D(RNN2DBase):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, input_size: int, hidden_size: int,
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self, input_size: int, hidden_size: int,
|
|
|
|
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
|
|
|
|
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
|
|
|
|
union="cat", with_fc=True):
|
|
|
|
union="cat", with_fc=True):
|
|
|
|
super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
|
|
|
|
super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
|
|
|
@ -194,10 +202,10 @@ class LSTM2D(RNN2DBase):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Sequencer2DBlock(nn.Module):
|
|
|
|
class Sequencer2DBlock(nn.Module):
|
|
|
|
def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
|
|
|
|
num_layers=1, bidirectional=True, union="cat", with_fc=True,
|
|
|
|
num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.):
|
|
|
|
drop=0., drop_path=0.):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
channels_dim = int(mlp_ratio * dim)
|
|
|
|
channels_dim = int(mlp_ratio * dim)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
@ -255,6 +263,7 @@ class Sequencer2D(nn.Module):
|
|
|
|
num_classes=1000,
|
|
|
|
num_classes=1000,
|
|
|
|
img_size=224,
|
|
|
|
img_size=224,
|
|
|
|
in_chans=3,
|
|
|
|
in_chans=3,
|
|
|
|
|
|
|
|
global_pool='avg',
|
|
|
|
layers=[4, 3, 8, 3],
|
|
|
|
layers=[4, 3, 8, 3],
|
|
|
|
patch_sizes=[7, 2, 1, 1],
|
|
|
|
patch_sizes=[7, 2, 1, 1],
|
|
|
|
embed_dims=[192, 384, 384, 384],
|
|
|
|
embed_dims=[192, 384, 384, 384],
|
|
|
@ -275,7 +284,9 @@ class Sequencer2D(nn.Module):
|
|
|
|
stem_norm=False,
|
|
|
|
stem_norm=False,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
assert global_pool in ('', 'avg')
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.num_features = embed_dims[-1] # num_features for consistency with other models
|
|
|
|
self.num_features = embed_dims[-1] # num_features for consistency with other models
|
|
|
|
self.embed_dims = embed_dims
|
|
|
|
self.embed_dims = embed_dims
|
|
|
|
self.stem = PatchEmbed(
|
|
|
|
self.stem = PatchEmbed(
|
|
|
@ -301,38 +312,54 @@ class Sequencer2D(nn.Module):
|
|
|
|
head_bias = -math.log(self.num_classes) if nlhb else 0.
|
|
|
|
head_bias = -math.log(self.num_classes) if nlhb else 0.
|
|
|
|
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
|
|
|
|
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
|
|
|
def group_matcher(self, coarse=False):
|
|
|
|
|
|
|
|
return dict(
|
|
|
|
|
|
|
|
stem=r'^stem',
|
|
|
|
|
|
|
|
blocks=[
|
|
|
|
|
|
|
|
(r'^blocks\.(\d+)\..*\.down', (99999,)),
|
|
|
|
|
|
|
|
(r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None),
|
|
|
|
|
|
|
|
(r'^norm', (99999,))
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
|
|
|
def set_grad_checkpointing(self, enable=True):
|
|
|
|
|
|
|
|
assert not enable, 'gradient checkpointing not supported'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
def get_classifier(self):
|
|
|
|
def get_classifier(self):
|
|
|
|
return self.head
|
|
|
|
return self.head
|
|
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool=''):
|
|
|
|
def reset_classifier(self, num_classes, global_pool=None):
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
|
|
|
if self.global_pool is not None:
|
|
|
|
|
|
|
|
assert global_pool in ('', 'avg')
|
|
|
|
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
x = self.stem(x)
|
|
|
|
x = self.stem(x)
|
|
|
|
x = self.blocks(x)
|
|
|
|
x = self.blocks(x)
|
|
|
|
x = self.norm(x)
|
|
|
|
x = self.norm(x)
|
|
|
|
x = x.mean(dim=(1, 2))
|
|
|
|
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
|
|
|
|
if self.global_pool == 'avg':
|
|
|
|
|
|
|
|
x = x.mean(dim=(1, 2))
|
|
|
|
|
|
|
|
return x if pre_logits else self.head(x)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
x = self.forward_features(x)
|
|
|
|
x = self.forward_features(x)
|
|
|
|
x = self.head(x)
|
|
|
|
x = self.forward_head(x)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_sequencer2d(variant, pretrained=False, **kwargs):
|
|
|
|
def _create_sequencer2d(variant, pretrained=False, **kwargs):
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
raise RuntimeError('features_only not implemented for Sequencer2D models.')
|
|
|
|
raise RuntimeError('features_only not implemented for Sequencer2D models.')
|
|
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs)
|
|
|
|
Sequencer2D, variant, pretrained,
|
|
|
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|