@ -1,5 +1,10 @@
""" Model creation / weight loading / state_dict helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import logging
import os
import os
import math
from collections import OrderedDict
from collections import OrderedDict
from copy import deepcopy
from copy import deepcopy
from typing import Callable
from typing import Callable
@ -86,11 +91,40 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
if in_chans == 1 :
if in_chans == 1 :
conv1_name = cfg [ ' first_conv ' ]
conv1_name = cfg [ ' first_conv ' ]
_logger . info ( ' Converting first conv ( %s ) from 3 to 1 channel' % conv1_name )
_logger . info ( ' Converting first conv ( %s ) pretrained weights from 3 to 1 channel' % conv1_name )
conv1_weight = state_dict [ conv1_name + ' .weight ' ]
conv1_weight = state_dict [ conv1_name + ' .weight ' ]
state_dict [ conv1_name + ' .weight ' ] = conv1_weight . sum ( dim = 1 , keepdim = True )
# Some weights are in torch.half, ensure it's float for sum on CPU
conv1_type = conv1_weight . dtype
conv1_weight = conv1_weight . float ( )
O , I , J , K = conv1_weight . shape
if I > 3 :
assert conv1_weight . shape [ 1 ] % 3 == 0
# For models with space2depth stems
conv1_weight = conv1_weight . reshape ( O , I / / 3 , 3 , J , K )
conv1_weight = conv1_weight . sum ( dim = 2 , keepdim = False )
else :
conv1_weight = conv1_weight . sum ( dim = 1 , keepdim = True )
conv1_weight = conv1_weight . to ( conv1_type )
state_dict [ conv1_name + ' .weight ' ] = conv1_weight
elif in_chans != 3 :
elif in_chans != 3 :
assert False , " Invalid in_chans for pretrained weights "
conv1_name = cfg [ ' first_conv ' ]
conv1_weight = state_dict [ conv1_name + ' .weight ' ]
conv1_type = conv1_weight . dtype
conv1_weight = conv1_weight . float ( )
O , I , J , K = conv1_weight . shape
if I != 3 :
_logger . warning ( ' Deleting first conv ( %s ) from pretrained weights. ' % conv1_name )
del state_dict [ conv1_name + ' .weight ' ]
strict = False
else :
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
_logger . info ( ' Repeating first conv ( %s ) weights in channel dim. ' % conv1_name )
repeat = int ( math . ceil ( in_chans / 3 ) )
conv1_weight = conv1_weight . repeat ( 1 , repeat , 1 , 1 ) [ : , : in_chans , : , : ]
conv1_weight * = ( 3 / float ( in_chans ) )
conv1_weight = conv1_weight . to ( conv1_type )
state_dict [ conv1_name + ' .weight ' ] = conv1_weight
classifier_name = cfg [ ' classifier ' ]
classifier_name = cfg [ ' classifier ' ]
if num_classes == 1000 and cfg [ ' num_classes ' ] == 1001 :
if num_classes == 1000 and cfg [ ' num_classes ' ] == 1001 :