Fix GroupedLinear FP8 calibration loop #3101

Open
fallintoplace wants to merge 1 commit into NVIDIA:main from fallintoplace:fix/grouped-linear-calibration

Conversation

@fallintoplace

Summary

  • remove the unused outer calibration loop in GroupedLinear
  • calibrate each input and weight quantizer once per GEMM
  • avoid repeating calibration num_gemms times for every GEMM
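
The difference between the two loop shapes can be sketched as follows. This is a minimal illustration reconstructed from the description in this PR, not the actual diff: `CountingQuantizer`, `calibrate_nested`, `calibrate_once`, and `call_counts` are hypothetical names, and the stand-in quantizer only counts `calibrate()` calls rather than tracking any FP8 statistics.

```python
class CountingQuantizer:
    """Hypothetical stand-in for a quantizer; it only counts calibrate() calls."""

    def __init__(self):
        self.calls = 0

    def calibrate(self, tensor):
        self.calls += 1


def calibrate_nested(input_quantizers, weight_quantizers, inputmats, weights, num_gemms):
    """Loop shape described as the old behaviour (an assumption, not the real code)."""
    for i in range(num_gemms):
        for i in range(num_gemms):
            input_quantizers[i].calibrate(inputmats[i])
        for i in range(num_gemms):
            weight_quantizers[i].calibrate(weights[i])


def calibrate_once(input_quantizers, weight_quantizers, inputmats, weights, num_gemms):
    """Loop shape described as the fix: one calibrate() call per quantizer per GEMM."""
    for i in range(num_gemms):
        input_quantizers[i].calibrate(inputmats[i])
        weight_quantizers[i].calibrate(weights[i])


def call_counts(calibrate_fn, num_gemms):
    """Run one calibration pass and report how often each quantizer was called."""
    input_quantizers = [CountingQuantizer() for _ in range(num_gemms)]
    weight_quantizers = [CountingQuantizer() for _ in range(num_gemms)]
    placeholders = [None] * num_gemms  # tensors are irrelevant to the call count
    calibrate_fn(input_quantizers, weight_quantizers, placeholders, placeholders, num_gemms)
    return (
        [q.calls for q in input_quantizers],
        [q.calls for q in weight_quantizers],
    )


old_counts = call_counts(calibrate_nested, 4)
new_counts = call_counts(calibrate_once, 4)
print("nested loops:", old_counts)
print("single loop: ", new_counts)
```

With `num_gemms = 4`, the nested shape calls every quantizer four times, while the single loop calls each one once.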

Validation

  • ran python3 -m py_compile transformer_engine/pytorch/module/grouped_linear.py
  • ran git diff --check -- transformer_engine/pytorch/module/grouped_linear.py

@github-actions github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) Jun 6, 2026
@fallintoplace fallintoplace marked this pull request as ready for review June 6, 2026 14:40
@fallintoplace fallintoplace requested a review from ksivaman as a code owner June 6, 2026 14:40
Signed-off-by: Minh Vu <vuhoangminh97@gmail.com>
@fallintoplace fallintoplace force-pushed the fix/grouped-linear-calibration branch from 506ef2d to 08c86d7 Compare June 6, 2026 14:41
@greptile-apps
Contributor

greptile-apps Bot commented Jun 6, 2026

Greptile Summary

This PR fixes a bug in the FP8 calibration block of _GroupedLinear.forward where a redundant outer loop caused each quantizer's calibrate() to be called num_gemms times per tensor instead of once, and both inner loops reused the outer loop variable i, masking the redundancy.

  • Calibration loop restructured: the three nested loops (outer over num_gemms, then two inner loops also over num_gemms) are replaced with a single loop that calls input_quantizers[i].calibrate(inputmats[i]) and weight_quantizers[i].calibrate(weights[i]) exactly once per GEMM.
  • The general_grouped_gemm call and all surrounding logic are unchanged; only the fp8_calibration branch is affected.

Confidence Score: 5/5

Safe to merge — the change is a one-block fix that removes clearly erroneous nested loops, with no logic changes elsewhere in the file.

The original code had an outer loop whose variable i was immediately shadowed by two inner loops, causing every quantizer to be calibrated num_gemms times with the same tensor. The fix collapses the three loops into one correct loop. The change is minimal, easy to verify by inspection, and no other code paths are touched.
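
Why the shadowed variable hid the redundancy without removing it can be seen in a small standalone sketch. The helper name `trace_shadowed_loop` is hypothetical and the loop bodies are placeholders; the point is only that in Python, rebinding the loop variable inside the body does not affect the outer `range` iterator, so the outer loop still runs its full count.

```python
def trace_shadowed_loop(n):
    """Record (outer pass, inner index) pairs for a loop that reuses `i`."""
    trace = []
    for i in range(n):
        outer_pass = i  # capture the outer value before the inner loop rebinds i
        for i in range(n):
            trace.append((outer_pass, i))
    return trace


trace = trace_shadowed_loop(3)
outer_passes = sorted({outer for outer, _ in trace})

print(len(trace))      # total inner-body executions
print(outer_passes)    # distinct outer passes that actually ran
```

For `n = 3` the inner body runs nine times across three outer passes, which matches the num_gemms-fold repetition described above.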

No files require special attention; the single changed block is straightforward.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/pytorch/module/grouped_linear.py | Removes a spurious outer for i in range(num_gemms) loop in the fp8_calibration block; input and weight quantizers are now calibrated exactly once each per GEMM instead of num_gemms times. |

Reviews (1): Last reviewed commit: "Fix grouped linear FP8 calibration loop"
