|
|
|
@ -7,10 +7,6 @@ attention in each block. The attention mechanisms used are linear in complexity.
|
|
|
|
|
|
|
|
|
|
DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
# Copyright (c) 2022 Mingyu Ding
|
|
|
|
|
# All rights reserved.
|
|
|
|
@ -442,7 +438,7 @@ class DaViT(nn.Module):
|
|
|
|
|
) for layer_id, item in enumerate(stage_param)
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
self.main_blocks.add_module(f'stage_{stage_id}', stage)
|
|
|
|
|
self.stages.add_module(f'stage_{stage_id}', stage)
|
|
|
|
|
|
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction = 2, module=f'stage_{stage_id}')]
|
|
|
|
|
|
|
|
|
@ -536,7 +532,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
out_dict = {}
|
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
|
k = k.replace('main_blocks.', 'main_blocks.stage_')
|
|
|
|
|
k = k.replace('main_blocks.', 'stages.stage_')
|
|
|
|
|
k = k.replace('head.', 'head.fc.')
|
|
|
|
|
out_dict[k] = v
|
|
|
|
|
return out_dict
|
|
|
|
@ -570,13 +566,13 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
|
|
|
|
|
|
'davit_tiny.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
|
|
|
|
|
'davit_small.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
|
|
|
|
# official microsoft weights from https://github.com/dingmyu/davit
|
|
|
|
|
'davit_tiny.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
|
|
|
|
|
'davit_small.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|