diff --git a/train.py b/train.py index 1e95c831..217f9a88 100755 --- a/train.py +++ b/train.py @@ -732,6 +732,11 @@ def evaluate( elif dev_env.type_cuda: dev_env.synchronize() + # FIXME uncommenting this fixes race btw model `output`/`loss` and loss_m/accuracy_m meter input + # for PyTorch XLA GPU use. + # This issue does not exist for normal PyTorch w/ GPU (CUDA) or PyTorch XLA w/ TPU. + # loss.item() + tracker.mark_iter_step_end() losses_m.update(loss, output.size(0)) accuracy_m.update(output, target)