diff --git a/train.py b/train.py index 7ab934b4..712e2055 100755 --- a/train.py +++ b/train.py @@ -693,6 +693,11 @@ def train_one_epoch( losses_m.update(reduced_loss.item(), input.size(0)) if args.local_rank == 0: + if (args.gpu_load != None): + total_input_size = input.size(0) / args.gpu_load[0] * args.world_size + else: + total_input_size = input.size(0) * args.world_size + _logger.info( 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' @@ -705,8 +710,8 @@ def train_one_epoch( 100. * batch_idx / last_idx, loss=losses_m, batch_time=batch_time_m, - rate=input.size(0) * args.world_size / batch_time_m.val, - rate_avg=input.size(0) * args.world_size / batch_time_m.avg, + rate=total_input_size / batch_time_m.val, + rate_avg=total_input_size / batch_time_m.avg, lr=lr, data_time=data_time_m))