Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/pytorch/distributed/run_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,10 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32:
# TF32 has same mantissa bits as FP16
return {"rtol": 1e-3, "atol": 1e-5}
# TF32 has same mantissa bits as FP16. The atol is looser than for FP16
# because near-zero gradient elements can differ by a few 1e-5 between
# the TP-sharded and single-device GEMM reduction orders (observed on A100).
Comment on lines +210 to +212
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Jun 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit disturbing. Even if TF32 has errors, shouldn't it be strictly better than FP16?

This makes me think there are other differences going on, like maybe the FP32 GEMM kernel is different between TP and non-TP, while it is consistent for FP16?

return {"rtol": 1e-3, "atol": 5e-5}
raise ValueError(f"Unsupported dtype ({dtype})")


Expand Down