Update num_classes in reset_classifier in resnetv2

pull/537/head
Mohamed Al Salti 4 years ago committed by GitHub
parent de9dff933a
commit bcb4879390
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -365,6 +365,7 @@ class ResNetV2(nn.Module):
return self.head.fc return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.head = ClassifierHead( self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

Loading…
Cancel
Save