You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
2.0 KiB
60 lines
2.0 KiB
3 years ago
|
from torch.nn.modules.batchnorm import BatchNorm2d
|
||
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||
|
|
||
|
import timm
|
||
|
from timm.utils.model import freeze, unfreeze
|
||
|
|
||
|
|
||
|
def test_freeze_unfreeze():
    """Check ``freeze``/``unfreeze`` on a ``resnet18`` created through timm.

    Covers four scenarios: (un)freezing the whole model, (un)freezing a
    selection of named submodules, and (un)freezing a single batch norm layer
    addressed either from the model root or from its direct parent module.
    Parameter state is checked through ``requires_grad``; batch norm state is
    checked through the layer's type (``FrozenBatchNorm2d`` vs ``BatchNorm2d``).
    """
    model = timm.create_model('resnet18')

    # Freeze all
    freeze(model)
    # Check top level module
    assert not model.fc.weight.requires_grad
    # Check submodule
    assert not model.layer1[0].conv1.weight.requires_grad
    # Check BN
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)

    # Unfreeze all
    unfreeze(model)
    # Check top level module
    assert model.fc.weight.requires_grad
    # Check submodule
    assert model.layer1[0].conv1.weight.requires_grad
    # Check BN
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

    # Freeze some
    freeze(model, ['layer1', 'layer2.0'])
    # Check frozen
    assert not model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    assert not model.layer2[0].conv1.weight.requires_grad
    # Check not frozen
    assert model.layer3[0].conv1.weight.requires_grad
    assert isinstance(model.layer3[0].bn1, BatchNorm2d)
    assert model.layer2[1].conv1.weight.requires_grad

    # Unfreeze some
    unfreeze(model, ['layer1', 'layer2.0'])
    # Check not frozen
    assert model.layer1[0].conv1.weight.requires_grad
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    assert model.layer2[0].conv1.weight.requires_grad

    # Freeze / unfreeze a single BN layer.
    # Each call below starts from the opposite state, so an assertion can only
    # pass if that specific call actually performed the conversion.

    # From root
    freeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

    # From direct parent
    freeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
|