@ -365,6 +365,7 @@ class ResNetV2(nn.Module):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)