From 87e2033b1d2ef242e9ad8ec98fc69a858d9ff61f Mon Sep 17 00:00:00 2001 From: "Reddy, Yeshwanth" Date: Sun, 15 Aug 2021 17:43:22 +0530 Subject: [PATCH] topk bug fix for binary classification --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index f1c1581e..4bcdaf3c 100755 --- a/train.py +++ b/train.py @@ -761,7 +761,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') target = target[0:target.size(0):reduce_factor] loss = loss_fn(output, target) - acc1, acc5 = accuracy(output, target, topk=(1, 5)) + acc1, acc5 = accuracy(output, target, topk=(1, min(5, model.num_classes))) if args.distributed: reduced_loss = reduce_tensor(loss.data, args.world_size)