@ -266,6 +266,10 @@ class BasicLayer(nn.Module):
)
for i in range(depth)])
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x):
x = self.downsample(x)
for blk in self.blocks: