From f88c606fcf47b6534241950e5ab9a1ce5cd7a5c5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 25 Apr 2022 12:41:46 -0700 Subject: [PATCH] fixing channels_last on cond_conv2d; update nvfuser debug env variable --- timm/models/layers/cond_conv2d.py | 3 ++- timm/utils/jit.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index 8b4bbca8..43654c59 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -91,7 +91,8 @@ class CondConv2d(nn.Module): bias = torch.matmul(routing_weights, self.bias) bias = bias.view(B * self.out_channels) # move batch elements with channels so each batch element can be efficiently convolved with separate kernel - x = x.view(1, B * C, H, W) + # reshape instead of view to work with channels_last input + x = x.reshape(1, B * C, H, W) if self.dynamic_padding: out = conv2d_same( x, weight, bias, stride=self.stride, padding=self.padding, diff --git a/timm/utils/jit.py b/timm/utils/jit.py index 6039823f..8ebfdbff 100644 --- a/timm/utils/jit.py +++ b/timm/utils/jit.py @@ -34,9 +34,9 @@ def set_jit_fuser(fuser): torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_set_texpr_fuser_enabled(False) elif fuser == "nvfuser" or fuser == "nvf": - os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' - os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' - os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' + os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' + os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' + os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True)