diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 80e0943d..9c332235 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -365,6 +365,7 @@ class ResNetV2(nn.Module): return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)