diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index ac304d3379..dcc0015015 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -599,11 +599,8 @@ def forward( if fp8_calibration: for i in range(num_gemms): - # amax of input - for i in range(num_gemms): - input_quantizers[i].calibrate(inputmats[i]) - for i in range(num_gemms): - weight_quantizers[i].calibrate(weights[i]) + input_quantizers[i].calibrate(inputmats[i]) + weight_quantizers[i].calibrate(weights[i]) if cpu_offloading: mark_not_offload(*weights_fp8, *weights)