diff --git a/train.py b/train.py index f1c1581e..4bcdaf3c 100755 --- a/train.py +++ b/train.py @@ -761,7 +761,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='') target = target[0:target.size(0):reduce_factor] loss = loss_fn(output, target) - acc1, acc5 = accuracy(output, target, topk=(1, 5)) + acc1, acc5 = accuracy(output, target, topk=(1, min(5, model.num_classes))) if args.distributed: reduced_loss = reduce_tensor(loss.data, args.world_size)