Fix visformer in_chans stem handling

pull/637/head
Ross Wightman 4 years ago
parent fd92ba0de8
commit 5db7452173

@@ -190,7 +190,7 @@ EXCLUDE_JIT_FILTERS = [
def test_model_forward_torchscript(model_name, batch_size):
"""Run a single forward pass with each model"""
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
-    if max(input_size) > MAX_JIT_SIZE:  # NOTE using MAX_FWD_SIZE as the final limit is intentional
+    if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):

@@ -26,7 +26,7 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'first_conv': 'stem.0', 'classifier': 'head',
**kwargs
}
@@ -183,7 +183,7 @@ class Visformer(nn.Module):
img_size //= 8
else:
self.stem = nn.Sequential(
-                nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False),
+                nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(self.init_channels),
nn.ReLU(inplace=True)
)

Loading…
Cancel
Save