diff --git a/tests/test_models.py b/tests/test_models.py index c4ba8db3..f5698462 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -115,8 +115,9 @@ if 'GITHUB_ACTIONS' not in os.environ: @pytest.mark.parametrize('model_name', list_models(pretrained=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_load_pretrained(model_name, batch_size): - """Run a single forward pass with each model""" - create_model(model_name, pretrained=True) + """Create that pretrained weights load, verify support for in_chans != 3 while doing so.""" + in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change + create_model(model_name, pretrained=True, in_chans=in_chans) EXCLUDE_JIT_FILTERS = [ diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 60888205..f1702af0 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -1,5 +1,10 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" import logging import os +import math from collections import OrderedDict from copy import deepcopy from typing import Callable @@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non if in_chans == 1: conv1_name = cfg['first_conv'] - _logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) + _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) conv1_weight = state_dict[conv1_name + '.weight'] - state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) + # Some weights are in torch.half, ensure it's float for sum on CPU + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I > 3: + assert conv1_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) + else: + conv1_weight = conv1_weight.sum(dim=1, keepdim=True) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight elif in_chans != 3: - assert False, "Invalid in_chans for pretrained weights" + conv1_name = cfg['first_conv'] + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I != 3: + _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) + del state_dict[conv1_name + '.weight'] + strict = False + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name) + repeat = int(math.ceil(in_chans / 3)) + conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv1_weight *= (3 / float(in_chans)) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight classifier_name = cfg['classifier'] if num_classes == 1000 and cfg['num_classes'] == 1001: