@ -1,5 +1,10 @@
""" Model creation / weight loading / state_dict helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
if in_chans == 1 :
conv1_name = cfg [ ' first_conv ' ]
_logger . info ( ' Converting first conv ( %s ) from 3 to 1 channel' % conv1_name )
_logger . info ( ' Converting first conv ( %s ) pretrained weights from 3 to 1 channel' % conv1_name )
conv1_weight = state_dict [ conv1_name + ' .weight ' ]
state_dict [ conv1_name + ' .weight ' ] = conv1_weight . sum ( dim = 1 , keepdim = True )
# Some weights are in torch.half, ensure it's float for sum on CPU
conv1_type = conv1_weight . dtype
conv1_weight = conv1_weight . float ( )
O , I , J , K = conv1_weight . shape
if I > 3 :
assert conv1_weight . shape [ 1 ] % 3 == 0
# For models with space2depth stems
conv1_weight = conv1_weight . reshape ( O , I / / 3 , 3 , J , K )
conv1_weight = conv1_weight . sum ( dim = 2 , keepdim = False )
else :
conv1_weight = conv1_weight . sum ( dim = 1 , keepdim = True )
conv1_weight = conv1_weight . to ( conv1_type )
state_dict [ conv1_name + ' .weight ' ] = conv1_weight
elif in_chans != 3 :
assert False , " Invalid in_chans for pretrained weights "
conv1_name = cfg [ ' first_conv ' ]
conv1_weight = state_dict [ conv1_name + ' .weight ' ]
conv1_type = conv1_weight . dtype
conv1_weight = conv1_weight . float ( )
O , I , J , K = conv1_weight . shape
if I != 3 :
_logger . warning ( ' Deleting first conv ( %s ) from pretrained weights. ' % conv1_name )
del state_dict [ conv1_name + ' .weight ' ]
strict = False
else :
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
_logger . info ( ' Repeating first conv ( %s ) weights in channel dim. ' % conv1_name )
repeat = int ( math . ceil ( in_chans / 3 ) )
conv1_weight = conv1_weight . repeat ( 1 , repeat , 1 , 1 ) [ : , : in_chans , : , : ]
conv1_weight * = ( 3 / float ( in_chans ) )
conv1_weight = conv1_weight . to ( conv1_type )
state_dict [ conv1_name + ' .weight ' ] = conv1_weight
classifier_name = cfg [ ' classifier ' ]
if num_classes == 1000 and cfg [ ' num_classes ' ] == 1001 :