|
|
@ -365,6 +365,7 @@ class ResNetV2(nn.Module):
|
|
|
|
return self.head.fc
|
|
|
|
return self.head.fc
|
|
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
|
|
|
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
|
|
|
|
|
|
|
|
|
|
|