|
|
@ -17,7 +17,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|
|
|
# This source code is licensed under the MIT license
|
|
|
|
# This source code is licensed under the MIT license
|
|
|
|
|
|
|
|
|
|
|
|
import itertools
|
|
|
|
import itertools
|
|
|
|
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, TypeVar, Union, List
|
|
|
|
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
@ -36,11 +36,28 @@ from .registry import register_model
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
|
|
|
|
def forward(self, inputs : Tuple):
|
|
|
|
|
|
|
|
for module in self._modules.values():
|
|
|
|
|
|
|
|
if type(inputs) == tuple:
|
|
|
|
|
|
|
|
inputs = module(*inputs)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
inputs = module(inputs)
|
|
|
|
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
|
|
|
|
@overload
|
|
|
|
|
|
|
|
def forward(self, inputs : Tensor):
|
|
|
|
|
|
|
|
for module in self._modules.values():
|
|
|
|
|
|
|
|
inputs = module(inputs)
|
|
|
|
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@overload
|
|
|
|
def forward(self, inputs : Tuple[Tensor, Tensor]):
|
|
|
|
def forward(self, inputs : Tuple[Tensor, Tensor]):
|
|
|
|
for module in self._modules.values():
|
|
|
|
for module in self._modules.values():
|
|
|
|
inputs = module(*inputs)
|
|
|
|
inputs = module(*inputs)
|
|
|
|
return inputs
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
def __init__(self, dim, k=3, act=False, normtype=False):
|
|
|
|
def __init__(self, dim, k=3, act=False, normtype=False):
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
@ -539,6 +556,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
state_dict = state_dict['state_dict']
|
|
|
|
state_dict = state_dict['state_dict']
|
|
|
|
|
|
|
|
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
|
|
|
|
|
import re
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
|
|
|
|
|
|
|
|
k = k.replace('head.', 'head.fc.')
|
|
|
|
k = k.replace('head.', 'head.fc.')
|
|
|
|