Merge branch 'main' into davit

pull/1630/head
Fredo Guan 3 years ago committed by GitHub
commit 11f27df29f

@@ -17,9 +17,11 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -35,35 +37,9 @@ from .registry import register_model
__all__ = ['DaViT']
'''
class MySequential(nn.Sequential):
    def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]):
        for module in self:
            inputs = module(*inputs)
            #if type(inputs) == tuple:
            #    inputs = module(*inputs)
            #else:
            #    inputs = module(inputs)
        return inputs
'''
'''
class MySequential(nn.Sequential):
    @overload
    def forward(self, inputs : Tensor):
        for module in self._modules.values():
            inputs = module(inputs)
        return inputs

    @overload
    def forward(self, inputs : Tuple[Tensor, Tensor]):
        for module in self._modules.values():
            inputs = module(*inputs)
        return inputs
'''
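Both quoted-out variants above exist because nn.Sequential.forward only threads a single positional input through its children, while the DaViT blocks pass an (x, size) pair; note that typing.overload only affects static type checking and gives no runtime dispatch. A minimal sketch of a tuple-aware container in that spirit (the name TupleSequential is illustrative, not from this diff):

from typing import Tuple

import torch.nn as nn
from torch import Tensor


class TupleSequential(nn.Sequential):
    """Sketch: a Sequential that forwards an (x, size) pair through every child."""

    def forward(self, x: Tensor, size: Tuple[int, int]):
        # every child module is expected to accept and return (x, size)
        for module in self:
            x, size = module(x, size)
        return x, size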
class ConvPosEnc(nn.Module):
    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
        super(ConvPosEnc, self).__init__()
        self.proj = nn.Conv2d(dim,
                              dim,
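The hunk above is truncated by the diff context. For orientation, a minimal sketch of a depthwise convolutional position encoding in the same spirit, assuming NCHW input and omitting the optional activation and normalization that ConvPosEnc supports; the class name is illustrative:

import torch
import torch.nn as nn


class DepthwiseConvPosEnc(nn.Module):
    def __init__(self, dim: int, k: int = 3):
        super().__init__()
        # groups=dim makes this a depthwise conv: one k x k filter per channel
        self.proj = nn.Conv2d(dim, dim, k, stride=1, padding=k // 2, groups=dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W); the conv output is added back as a positional bias
        return x + self.proj(x)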
@@ -130,6 +106,7 @@ class PatchEmbed(nn.Module):
            padding=to_2tuple(pad))
        self.norm = nn.LayerNorm(in_chans)

    def forward(self, x, size: Tuple[int, int]):
        H, W = size
        dim = len(x.shape)
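The `dim = len(x.shape)` check indicates the patch embed accepts both a (B, H*W, C) token sequence and a (B, C, H, W) map. A small sketch of that rank handling with made-up shapes; this is not the exact body of PatchEmbed.forward:

import torch


def to_nchw(x: torch.Tensor, size):
    """Sketch: normalize input rank so a strided Conv2d can downsample it."""
    H, W = size
    if x.dim() == 3:                      # (B, H*W, C) token sequence
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, H, W)
    return x                              # already (B, C, H, W)


x_tokens = torch.randn(2, 7 * 7, 96)
assert to_nchw(x_tokens, (7, 7)).shape == (2, 96, 7, 7)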
@@ -203,6 +180,7 @@ class ChannelBlock(nn.Module):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer)

    def forward(self, x, size: Tuple[int, int]):
        x = self.cpe[0](x, size)
        cur = self.norm1(x)
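ChannelBlock applies its conv position encoding, then a pre-norm channel-attention branch. A sketch of channel-wise self-attention in the DaViT spirit, where the attention matrix is (C/heads x C/heads) over channels rather than (N x N) over tokens; an illustration, not the timm ChannelAttention verbatim:

import torch
import torch.nn as nn


class ChannelAttentionSketch(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)       # each (B, heads, N, C/heads)
        attn = (k.transpose(-1, -2) @ v) * self.scale         # (B, heads, C/h, C/h): over channels
        attn = attn.softmax(dim=-1)
        x = (attn @ q.transpose(-1, -2)).transpose(-1, -2)    # back to (B, heads, N, C/h)
        return self.proj(x.transpose(1, 2).reshape(B, N, C))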
@@ -327,7 +305,9 @@ class SpatialBlock(nn.Module):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer)

    def forward(self, x, size: Tuple[int, int]):
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
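The `L == H * W` assert guards the reshape from the (B, H*W, C) token sequence back onto the spatial grid before windowed attention. A sketch of the Swin-style window partition this block relies on (square windows, no padding shown):

import torch


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Sketch: (B, H, W, C) -> (num_windows * B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


x = torch.randn(2, 14, 14, 96)
assert window_partition(x, 7).shape == (2 * 2 * 2, 7, 7, 96)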
@@ -538,6 +518,7 @@ class DaViT(nn.Module):
                    features.append(x)
                    sizes.append(size)
                    branches.append(branch_id)

            block_index : int = block_index
@@ -552,11 +533,14 @@ class DaViT(nn.Module):
            for layer_index, branch_id in enumerate(block_param):
                layer_index : int = layer_index
                branch_id : int = branch_id
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
                else:
                    features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
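Gradient checkpointing trades compute for memory: activations inside the block are recomputed during backward instead of being stored, and the torch.jit.is_scripting() guard is there because checkpoint is not supported under TorchScript. A standalone sketch of the same pattern with a toy block:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(2, 64, requires_grad=True)

grad_checkpointing = True
if grad_checkpointing and not torch.jit.is_scripting():
    y = checkpoint.checkpoint(block, x)   # re-runs block in backward, saves activation memory
else:
    y = block(x)
y.sum().backward()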
        # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
        outs = []
        for i in range(self.num_stages):
@@ -565,6 +549,7 @@ class DaViT(nn.Module):
            H, W = sizes[i]
            out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
            outs.append(out)
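Each pyramid output is normalized and converted from a (B, H*W, C) token sequence into a (B, C, H, W) feature map for dense downstream heads. A tiny standalone check of that view/permute, with made-up sizes:

import torch

B, H, W, C = 2, 7, 7, 96
x_out = torch.randn(B, H * W, C)                                  # stage output tokens
out = x_out.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()    # (B, C, H, W)
assert out.shape == (B, C, H, W)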
