From 08c86d75c68ab2fc896ae65a1005d7e3cf4f0406 Mon Sep 17 00:00:00 2001
From: Minh Vu
Date: Sat, 6 Jun 2026 16:39:30 +0200
Subject: [PATCH] Fix grouped linear FP8 calibration loop

Signed-off-by: Minh Vu
---
 transformer_engine/pytorch/module/grouped_linear.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index ac304d3379..dcc0015015 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -599,11 +599,8 @@ def forward(
 
         if fp8_calibration:
             for i in range(num_gemms):
-                # amax of input
-                for i in range(num_gemms):
-                    input_quantizers[i].calibrate(inputmats[i])
-                for i in range(num_gemms):
-                    weight_quantizers[i].calibrate(weights[i])
+                input_quantizers[i].calibrate(inputmats[i])
+                weight_quantizers[i].calibrate(weights[i])
 
         if cpu_offloading:
             mark_not_offload(*weights_fp8, *weights)