Fix a few more issues related to #216 w/ TResNet (space2depth) and FP16 weights in wide resnets. Also don't completely dump pretrained weights in in_chans != 1 or 3 cases.

pull/227/head
Ross Wightman 4 years ago
parent 512b2dd645
commit b1b6e7c361

@ -115,8 +115,9 @@ if 'GITHUB_ACTIONS' not in os.environ:
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_load_pretrained(model_name, batch_size):
"""Run a single forward pass with each model"""
create_model(model_name, pretrained=True)
"""Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
create_model(model_name, pretrained=True, in_chans=in_chans)
EXCLUDE_JIT_FILTERS = [

@ -1,5 +1,10 @@
""" Model creation / weight loading / state_dict helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
if in_chans == 1:
conv1_name = cfg['first_conv']
_logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight']
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
# Some weights are in torch.half, ensure it's float for sum on CPU
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
O, I, J, K = conv1_weight.shape
if I > 3:
assert conv1_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
else:
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
elif in_chans != 3:
assert False, "Invalid in_chans for pretrained weights"
conv1_name = cfg['first_conv']
conv1_weight = state_dict[conv1_name + '.weight']
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
O, I, J, K = conv1_weight.shape
if I != 3:
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
del state_dict[conv1_name + '.weight']
strict = False
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
repeat = int(math.ceil(in_chans / 3))
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv1_weight *= (3 / float(in_chans))
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
classifier_name = cfg['classifier']
if num_classes == 1000 and cfg['num_classes'] == 1001:

Loading…
Cancel
Save