topk bug fix for binary classification

pull/808/head
Reddy, Yeshwanth 4 years ago
parent 3cdaf5ed56
commit 87e2033b1d

@ -761,7 +761,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1, acc5 = accuracy(output, target, topk=(1, min(5, model.num_classes)))
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)

Loading…
Cancel
Save