diff --git a/tests/test_utils.py b/tests/test_utils.py index 3e11eacc..b0f890d2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -44,17 +44,14 @@ def test_freeze_unfreeze(): assert isinstance(model.layer1[0].bn1, BatchNorm2d) assert model.layer2[0].conv1.weight.requires_grad == True - # Freeze BN + # Freeze/unfreeze BN # From root freeze(model, ['layer1.0.bn1']) assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) - # From direct parent - freeze(model.layer1[0], ['bn1']) - assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) - - # Unfreeze BN unfreeze(model, ['layer1.0.bn1']) assert isinstance(model.layer1[0].bn1, BatchNorm2d) # From direct parent + freeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) unfreeze(model.layer1[0], ['bn1']) assert isinstance(model.layer1[0].bn1, BatchNorm2d) \ No newline at end of file