@@ -331,7 +331,7 @@ def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, nor

     if 'fixed' in stem_type:
         # 'fixed' SAME padding approximation that is used in BiT models
-        stem['pad'] = nn.ConstantPad2d(1, 0)
+        stem['pad'] = nn.ConstantPad2d(1, 0.)
         stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
     elif 'same' in stem_type:
         # full, input size based 'SAME' padding, used in ViT Hybrid model
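Note: the one-line change above only swaps the integer fill value `0` for the float `0.` (presumably for dtype/scripting consistency); the padding itself is unchanged. For context, here is a minimal standalone sketch, not part of this diff, contrasting the BiT 'fixed' approximation with true input-size-dependent 'SAME' padding; `fixed_same_pool` and `tf_same_pool` are hypothetical names. For a 3x3/stride-2 pool the two always produce the same output size; only the edge alignment differs on even-sized inputs, which is the approximation the comment refers to.

```python
import math
import torch
import torch.nn.functional as F

def fixed_same_pool(x):
    # BiT-style 'fixed' approximation: constant pad of 1 on every side,
    # then an unpadded 3x3/stride-2 max pool.
    x = F.pad(x, [1, 1, 1, 1], value=0.)
    return F.max_pool2d(x, kernel_size=3, stride=2, padding=0)

def tf_same_pool(x):
    # True TF 'SAME': total padding depends on input size,
    # total = max((ceil(i / 2) - 1) * 2 + 3 - i, 0), with the extra pixel
    # on the right/bottom. Even sizes pad (0, 1); odd sizes pad (1, 1).
    ih, iw = x.shape[-2:]
    ph = max((math.ceil(ih / 2) - 1) * 2 + 3 - ih, 0)
    pw = max((math.ceil(iw / 2) - 1) * 2 + 3 - iw, 0)
    x = F.pad(x, [pw // 2, pw - pw // 2, ph // 2, ph - ph // 2])
    return F.max_pool2d(x, kernel_size=3, stride=2, padding=0)

x = torch.randn(1, 64, 112, 112)
print(fixed_same_pool(x).shape)  # torch.Size([1, 64, 56, 56])
print(tf_same_pool(x).shape)     # torch.Size([1, 64, 56, 56])
```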
@@ -421,7 +421,12 @@ class ResNetV2(nn.Module):
         import numpy as np
         weights = np.load(checkpoint_path)
         with torch.no_grad():
-            self.stem.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
+            stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
+            if self.stem.conv.weight.shape[1] == 1:
+                self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
+                # FIXME handle > 3 in_chans?
+            else:
+                self.stem.conv.weight.copy_(stem_conv_w)
             self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
             self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
             self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
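The new branch adapts the 3-channel pretrained stem kernel to `in_chans=1` models by summing over the input-channel dimension; inputs with more than 3 channels are left as a FIXME. A standalone sketch of why the sum is the right reduction, assuming `tf2th` transposes TF HWIO kernels to OIHW (the `tf2th_like` helper below is illustrative, not the file's actual function): the summed kernel reproduces, on a grayscale image, exactly what the RGB kernel computes on that image replicated across three channels.

```python
import numpy as np
import torch
import torch.nn.functional as F

def tf2th_like(w):
    # HWIO -> OIHW for 4-d conv kernels, mirroring what tf2th presumably does.
    return torch.from_numpy(w.transpose(3, 2, 0, 1) if w.ndim == 4 else w)

rgb_w = tf2th_like(np.random.randn(7, 7, 3, 64).astype(np.float32))  # (64, 3, 7, 7)
gray_w = rgb_w.sum(dim=1, keepdim=True)                              # (64, 1, 7, 7)

# The convolution sums channel responses anyway, so for a single-channel x:
#   conv(rgb_w, [x, x, x]) == conv(rgb_w.sum(dim=1), x)
x = torch.randn(1, 1, 56, 56)
out_rgb = F.conv2d(x.repeat(1, 3, 1, 1), rgb_w, padding=3)
out_gray = F.conv2d(x, gray_w, padding=3)
print(torch.allclose(out_rgb, out_gray, atol=1e-4))  # True
```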