Fix a few more issues related to #216 w/ TResNet (space2depth) and FP16 weights in wide ResNets. Also, don't completely discard pretrained weights when in_chans != 1 or 3.

pull/227/head
Ross Wightman 4 years ago
parent 512b2dd645
commit b1b6e7c361

@@ -115,8 +115,9 @@ if 'GITHUB_ACTIONS' not in os.environ:
 @pytest.mark.parametrize('model_name', list_models(pretrained=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_load_pretrained(model_name, batch_size):
-    """Run a single forward pass with each model"""
-    create_model(model_name, pretrained=True)
+    """Check that pretrained weights load, verify support for in_chans != 3 while doing so."""
+    in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
+    create_model(model_name, pretrained=True, in_chans=in_chans)
 
 EXCLUDE_JIT_FILTERS = [

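For reference, the in_chans handling this updated test exercises can be tried directly via create_model. A minimal sketch, assuming timm is installed; resnet18 is an arbitrary choice of pretrained model, not one singled out by this commit:

import torch
from timm import create_model

# Load pretrained weights with a single-channel input; load_pretrained()
# sums the first conv's RGB weights across the channel dim.
model = create_model('resnet18', pretrained=True, in_chans=1)
out = model(torch.randn(1, 1, 224, 224))  # forward pass with 1-channel input
print(out.shape)  # torch.Size([1, 1000])
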
@@ -1,5 +1,10 @@
+""" Model creation / weight loading / state_dict helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
 import logging
 import os
+import math
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Callable
@@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None
     if in_chans == 1:
         conv1_name = cfg['first_conv']
-        _logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
+        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
         conv1_weight = state_dict[conv1_name + '.weight']
-        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
+        # Some weights are in torch.half, ensure it's float for sum on CPU
+        conv1_type = conv1_weight.dtype
+        conv1_weight = conv1_weight.float()
+        O, I, J, K = conv1_weight.shape
+        if I > 3:
+            assert conv1_weight.shape[1] % 3 == 0
+            # For models with space2depth stems
+            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
+            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
+        else:
+            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
+        conv1_weight = conv1_weight.to(conv1_type)
+        state_dict[conv1_name + '.weight'] = conv1_weight
     elif in_chans != 3:
-        assert False, "Invalid in_chans for pretrained weights"
+        conv1_name = cfg['first_conv']
+        conv1_weight = state_dict[conv1_name + '.weight']
+        conv1_type = conv1_weight.dtype
+        conv1_weight = conv1_weight.float()
+        O, I, J, K = conv1_weight.shape
+        if I != 3:
+            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
+            del state_dict[conv1_name + '.weight']
+            strict = False
+        else:
+            # NOTE this strategy should be better than random init, but there could be other combinations of
+            # the original RGB input layer weights that'd work better for specific cases.
+            _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
+            repeat = int(math.ceil(in_chans / 3))
+            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+            conv1_weight *= (3 / float(in_chans))
+            conv1_weight = conv1_weight.to(conv1_type)
+            state_dict[conv1_name + '.weight'] = conv1_weight
 
     classifier_name = cfg['classifier']
     if num_classes == 1000 and cfg['num_classes'] == 1001:
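To illustrate the weight transforms above outside of load_pretrained, here's a minimal standalone sketch; adapt_first_conv is a hypothetical helper written for this note, not part of the patch, mirroring the same sum (1-channel, incl. space2depth stems) and repeat-and-rescale (in_chans > 3) logic:

import math
import torch

def adapt_first_conv(weight, in_chans):
    # Hypothetical helper: adapt a pretrained first-conv weight of shape
    # (O, I, J, K) to a new number of input channels, as in the patch.
    conv_type = weight.dtype
    weight = weight.float()  # some weights are torch.half; sum needs float on CPU
    O, I, J, K = weight.shape
    if in_chans == 1:
        if I > 3:
            # space2depth stem: input channels are stacked RGB blocks, sum each block
            assert I % 3 == 0
            weight = weight.reshape(O, I // 3, 3, J, K).sum(dim=2)
        else:
            weight = weight.sum(dim=1, keepdim=True)
    else:
        # tile the RGB weights across the new channels, then rescale by
        # 3 / in_chans so the expected activation magnitude stays roughly unchanged
        repeat = int(math.ceil(in_chans / 3))
        weight = weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        weight *= (3 / float(in_chans))
    return weight.to(conv_type)

w = torch.randn(64, 3, 7, 7).half()   # e.g. an FP16 wide-ResNet stem weight
print(adapt_first_conv(w, 1).shape)   # torch.Size([64, 1, 7, 7])
print(adapt_first_conv(w, 4).shape)   # torch.Size([64, 4, 7, 7])
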
