Fix some broken tests for ResNetV2 BiT models

pull/323/head
Ross Wightman 4 years ago
parent fd9061dbf7
commit 20516abc18

@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*']
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*', '*in21k', '*152x4_bitm']
else:
EXCLUDE_FILTERS = ['vit_*']
MAX_FWD_SIZE = 384

@ -331,7 +331,7 @@ def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, nor
if 'fixed' in stem_type:
# 'fixed' SAME padding approximation that is used in BiT models
stem['pad'] = nn.ConstantPad2d(1, 0)
stem['pad'] = nn.ConstantPad2d(1, 0.)
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
elif 'same' in stem_type:
# full, input size based 'SAME' padding, used in ViT Hybrid model
@ -421,7 +421,12 @@ class ResNetV2(nn.Module):
import numpy as np
weights = np.load(checkpoint_path)
with torch.no_grad():
self.stem.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
if self.stem.conv.weight.shape[1] == 1:
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
# FIXME handle > 3 in_chans?
else:
self.stem.conv.weight.copy_(stem_conv_w)
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))

Loading…
Cancel
Save