|
|
@ -37,6 +37,6 @@ class RealLabelsImagenet:
|
|
|
|
|
|
|
|
|
|
|
|
def get_accuracy(self, k=None):
|
|
|
|
def get_accuracy(self, k=None):
|
|
|
|
if k is None:
|
|
|
|
if k is None:
|
|
|
|
return {k: float(np.mean(self.is_correct[k] for k in self.topk))}
|
|
|
|
return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return float(np.mean(self.is_correct[k])) * 100
|
|
|
|
return float(np.mean(self.is_correct[k])) * 100
|
|
|
|