|
|
@ -44,17 +44,14 @@ def test_freeze_unfreeze():
|
|
|
|
assert isinstance(model.layer1[0].bn1, BatchNorm2d)
|
|
|
|
assert isinstance(model.layer1[0].bn1, BatchNorm2d)
|
|
|
|
assert model.layer2[0].conv1.weight.requires_grad == True
|
|
|
|
assert model.layer2[0].conv1.weight.requires_grad == True
|
|
|
|
|
|
|
|
|
|
|
|
# Freeze BN
|
|
|
|
# Freeze/unfreeze BN
|
|
|
|
# From root
|
|
|
|
# From root
|
|
|
|
freeze(model, ['layer1.0.bn1'])
|
|
|
|
freeze(model, ['layer1.0.bn1'])
|
|
|
|
assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
|
|
|
|
assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
|
|
|
|
# From direct parent
|
|
|
|
|
|
|
|
freeze(model.layer1[0], ['bn1'])
|
|
|
|
|
|
|
|
assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Unfreeze BN
|
|
|
|
|
|
|
|
unfreeze(model, ['layer1.0.bn1'])
|
|
|
|
unfreeze(model, ['layer1.0.bn1'])
|
|
|
|
assert isinstance(model.layer1[0].bn1, BatchNorm2d)
|
|
|
|
assert isinstance(model.layer1[0].bn1, BatchNorm2d)
|
|
|
|
# From direct parent
|
|
|
|
# From direct parent
|
|
|
|
|
|
|
|
freeze(model.layer1[0], ['bn1'])
|
|
|
|
|
|
|
|
assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
|
|
|
|
unfreeze(model.layer1[0], ['bn1'])
|
|
|
|
unfreeze(model.layer1[0], ['bn1'])
|
|
|
|
assert isinstance(model.layer1[0].bn1, BatchNorm2d)
|
|
|
|
assert isinstance(model.layer1[0].bn1, BatchNorm2d)
|