Remove old mean/std helper, rely fully on cmd line or default_cfg now. Fixes #126

pull/136/head
Ross Wightman 4 years ago
parent 02a30411ad
commit 64fe37d008

@ -35,8 +35,6 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
# resolve dataset + model mean for normalization
new_config['mean'] = IMAGENET_DEFAULT_MEAN
if 'model' in args:
new_config['mean'] = get_mean_by_model(args['model'])
if 'mean' in args and args['mean'] is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
@ -49,8 +47,6 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
# resolve dataset + model std deviation for normalization
new_config['std'] = IMAGENET_DEFAULT_STD
if 'model' in args:
new_config['std'] = get_std_by_model(args['model'])
if 'std' in args and args['std'] is not None:
std = tuple(args['std'])
if len(std) == 1:
@ -74,23 +70,3 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
logging.info('\t%s: %s' % (n, str(v)))
return new_config
def get_mean_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_STD
elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
def get_std_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DEFAULT_STD
elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD

Loading…
Cancel
Save