Merge branch 'main' into davit

pull/1630/head
Fredo Guan authored 3 years ago, committed via GitHub
commit 11f27df29f

@@ -17,9 +17,11 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -35,35 +37,9 @@ from .registry import register_model
__all__ = ['DaViT']
'''
class MySequential(nn.Sequential):
    def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]):
        for module in self:
            inputs = module(*inputs)
            #if type(inputs) == tuple:
            #    inputs = module(*inputs)
            #else:
            #    inputs = module(inputs)
        return inputs
'''
'''
class MySequential(nn.Sequential):
    @overload
    def forward(self, inputs : Tensor):
        for module in self._modules.values():
            inputs = module(inputs)
        return inputs

    @overload
    def forward(self, inputs : Tuple[Tensor, Tensor]):
        for module in self._modules.values():
            inputs = module(*inputs)
        return inputs
'''
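For context, the removed MySequential helpers above were attempts at an nn.Sequential that unpacks a (tensor, size) tuple into each child block. A minimal standalone sketch of that idea follows; the class name, typing, and the assumption that every child returns the same (x, size) pair are illustrative and not part of this commit.

# Illustrative only, not part of this diff: a Sequential whose children each
# take (x, (H, W)) and return the same pair, so the pair threads through.
import torch
import torch.nn as nn
from typing import Tuple

class TupleSequential(nn.Sequential):
    def forward(self, x: torch.Tensor, size: Tuple[int, int]):
        for module in self:
            x, size = module(x, size)
        return x, size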
class ConvPosEnc(nn.Module):
    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
        super(ConvPosEnc, self).__init__()
        self.proj = nn.Conv2d(dim,
                              dim,
@@ -130,6 +106,7 @@ class PatchEmbed(nn.Module):
            padding=to_2tuple(pad))
        self.norm = nn.LayerNorm(in_chans)

    def forward(self, x, size: Tuple[int, int]):
        H, W = size
        dim = len(x.shape)
@@ -203,6 +180,7 @@ class ChannelBlock(nn.Module):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer)

    def forward(self, x, size: Tuple[int, int]):
        x = self.cpe[0](x, size)
        cur = self.norm1(x)
@@ -327,7 +305,9 @@ class SpatialBlock(nn.Module):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer)

    def forward(self, x, size: Tuple[int, int]):
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
@@ -540,6 +520,7 @@ class DaViT(nn.Module):
                    branches.append(branch_id)
            block_index : int = block_index
            if block_index not in branches:
@@ -552,11 +533,14 @@ class DaViT(nn.Module):
            for layer_index, branch_id in enumerate(block_param):
                layer_index : int = layer_index
                branch_id : int = branch_id
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
                else:
                    features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])

        # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
        outs = []
        for i in range(self.num_stages):
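A side note on the grad_checkpointing branch in the hunk above: torch.utils.checkpoint.checkpoint(fn, *args) runs fn without storing its intermediate activations and recomputes them during backward, trading extra compute for memory. A self-contained sketch with a stand-in block (the Linear layer and shapes are made up for illustration, not DaViT code):

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

blk = nn.Linear(8, 8)                       # stand-in for a DaViT block
x = torch.randn(2, 8, requires_grad=True)
y = checkpoint.checkpoint(blk, x)           # activations recomputed in backward
y.sum().backward()
print(blk.weight.grad.shape)                # torch.Size([8, 8])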
@@ -568,6 +552,7 @@ class DaViT(nn.Module):
'''

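Once the entrypoints added in this PR are registered via register_model, the model should be buildable through the usual timm factory. The exact entrypoint name below ('davit_tiny') and the absence of pretrained weights are assumptions based on this branch, not something shown in the diff:

import timm
import torch

# 'davit_tiny' is assumed to be one of the entrypoints registered by this PR
model = timm.create_model('davit_tiny', pretrained=False, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])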