Minor test change

pull/154/head
Ross Wightman 5 years ago
parent afb6bd0669
commit 3873ea710e

@ -66,5 +66,5 @@ def test_model_default_cfgs(model_name, batch_size):
input_size = tuple([min(x, 448) for x in input_size])
outputs = model.forward_features(torch.randn((batch_size, *input_size)))
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert any([k.startswith(cfg['classifier']) for k in state_dict.keys()]), f'{classifier} not in model params'
assert any([k.startswith(cfg['first_conv']) for k in state_dict.keys()]), f'{first_conv} not in model params'
assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'

Loading…
Cancel
Save