
Quantization support for GroupedTensor: FP8 per-tensor #3102

Open

int-smart wants to merge 3 commits into NVIDIA:main from int-smart:feature/fp8_quant

@int-smart (Contributor)

Description


Fixes #2449

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Kernels: Extended unary_kernel and unary_grad_kernel in vectorized_pointwise.h to support per-tensor scale, scale_inv, and amax for grouped tensors.
  • Alignment: Replaced the random padding in test_common.cu with a constant 64-element padding, which guarantees matching element offsets between input and output grouped tensors.
  • Verification: Corrected the FP8 cast validation loop in test_cast_fp8_grouped.cu to compare the raw quantized values directly, eliminating false test failures caused by rounding differences.
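To illustrate why deterministic padding yields matching offsets, here is a minimal C++ sketch. It reads "constant 64-element padding" as a fixed pad after each sub-tensor (if the helper instead rounds offsets up to a multiple of 64, the argument is the same); make_layout, GroupedLayout, and kPad are hypothetical names, not the actual test_common.cu helpers.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical fixed pad appended after every sub-tensor.
constexpr std::size_t kPad = 64;

struct GroupedLayout {
  std::vector<std::size_t> offsets;  // start element of each sub-tensor
  std::size_t total_elems = 0;       // buffer size in elements, padding included
};

// Offsets depend only on the element counts, so an input and an output
// grouped tensor built from the same shapes get identical offsets,
// independent of dtype. Random per-tensor padding would break this.
inline GroupedLayout make_layout(const std::vector<std::size_t>& numels) {
  GroupedLayout layout;
  std::size_t cursor = 0;
  for (std::size_t n : numels) {
    layout.offsets.push_back(cursor);
    cursor += n + kPad;  // deterministic gap between sub-tensors
  }
  layout.total_elems = cursor;
  return layout;
}
```

Because total_elems exceeds the sum of the sub-tensor sizes, a size check that demands exact equality with that sum would reject such a buffer; this is presumably why the data-size validation in transformer_engine.cpp was relaxed to >=.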

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

int-smart added 2 commits June 6, 2026 11:55
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
int-smart requested a review from Oleg-Goncharov as a code owner on June 7, 2026 05:05
github-actions (bot) added the community-contribution label (PRs from external contributors outside the core maintainers, representing community-driven work) on Jun 7, 2026
@greptile-apps (Contributor)

greptile-apps Bot commented Jun 7, 2026

Greptile Summary

This PR adds FP8 per-tensor (delayed) scaling support for GroupedTensor quantize and dequantize operations. It extends unary_kernel and unary_grad_kernel in vectorized_pointwise.h with new parameters for per-tensor scale, scale_inv, and amax arrays, and wires up the NVTE_DELAYED_TENSOR_SCALING dispatch path in both the forward and backward quantize/dequantize helpers.

  • Kernel extension: Both unary_kernel and unary_grad_kernel gain optional offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, and amax_numel parameters; a binary search (find_tensor_id) maps each element's global index to its tensor ID for per-tensor scale/amax dispatch.
  • Test infrastructure: build_grouped_tensor is updated to collect and upload per-tensor scale and amax buffers; random padding is fixed to 64 elements to guarantee matching offsets between input and output grouped tensors.
  • New tests: test_cast_fp8_grouped.cu and test_dequantize_fp8_grouped.cu cover grouped BF16/FP16/FP32→FP8 quantization and FP8→BF16/FP16/FP32 dequantization with varying shapes.
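The find_tensor_id binary search can be sketched on the host as below. This assumes offsets holds the starting element index of each sub-tensor in ascending order with offsets[0] == 0; the signature and index types are assumptions and may differ from the code in vectorized_pointwise.h. The loop avoids <algorithm> so the same body could be marked __host__ __device__.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical find_tensor_id-style lookup: returns the largest t with
// offsets[t] <= idx, i.e. the sub-tensor that owns global element idx.
// Cost is O(log num_tensors) per lookup.
inline std::size_t find_tensor_id(const std::int64_t* offsets,
                                  std::size_t num_tensors, std::int64_t idx) {
  std::size_t lo = 0, hi = num_tensors;  // answer lies in [lo, hi)
  while (hi - lo > 1) {
    const std::size_t mid = lo + (hi - lo) / 2;
    if (offsets[mid] <= idx) {
      lo = mid;  // idx belongs to tensor mid or a later one
    } else {
      hi = mid;  // idx belongs to a tensor before mid
    }
  }
  return lo;
}
```

Padding elements that sit between two sub-tensors map to the preceding tensor under this rule, so callers still need each tensor's element count to skip them.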

Confidence Score: 3/5

The kernel changes touch the hot path for every grouped FP8 quantization and are not safe to merge in their current form: a silent out-of-bounds write and a shared-memory correctness hazard both need to be resolved first.

The block_max[64] fixed array in both unary_kernel and unary_grad_kernel will silently corrupt GPU stack data whenever a grouped tensor has more than 64 sub-tensors, a realistic scenario in large MoE workloads. Separately, calling reduce_max in a per-tensor loop without a __syncthreads() between iterations leaves the shared staging array in a racy state, meaning per-tensor amax values can be computed from stale data. Both defects affect correctness of the core amax tracking that drives delayed FP8 scaling. The test shapes used (3 tensors, small sizes) happen to avoid triggering either issue, so the tests pass while masking real production failures.

transformer_engine/common/util/vectorized_pointwise.h needs the most attention: both the block_max array bound and the reduce_max loop barrier are in this file. transformer_engine/common/cast/fp8/quantize_fp8.cuh warrants a second look at the backward overload, which derives shapes and offsets from input.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/util/vectorized_pointwise.h | Extended unary_kernel and unary_grad_kernel with per-tensor scale/amax/scale_inv support; introduces a fixed-size block_max[64] array (out of bounds when num_tensors > 64) and a shared-memory race in the per-tensor amax reduction loop. |
| transformer_engine/common/cast/fp8/quantize_fp8.cuh | Adds forward and backward group_quantize wrappers; the backward overload derives N and offset arrays from the forward input tensor rather than the grad tensor, which may produce incorrect bounds when shapes differ. |
| transformer_engine/common/cast/fp8/dequantize_fp8.cuh | Adds group_dequantize using the existing VectorizedUnaryKernelLauncher; per-tensor scale_inv dispatch looks correct. |
| transformer_engine/common/cast/dispatch/quantize.cuh | Adds NVTE_DELAYED_TENSOR_SCALING dispatch cases for both forward and backward group quantize; straightforward plumbing. |
| tests/cpp/test_common.cu | Collects per-tensor scale and amax from individual tensors and uploads them into the grouped buffer; random padding changed to a fixed 64-element alignment; straightforward bookkeeping. |
| tests/cpp/operator/test_cast_fp8_grouped.cu | New test for grouped FP8 quantization; validates quantized values, scale_inv, and amax against a CPU reference; copy-back logic is correct, but the tests only cover the forward path. |
| tests/cpp/operator/test_dequantize_fp8_grouped.cu | New test for grouped FP8 dequantization; validates dequantized values against a CPU reference; logic appears correct. |
| transformer_engine/common/transformer_engine.cpp | Relaxes the data-size validation from exact equality to >= to accommodate the padded grouped-tensor layout; intentional and correctly documented. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_group_quantize / nvte_group_dequantize"] --> B["group_quantize_fwd_helper / group_dequantize_helper"]
    B --> C{scaling_mode?}
    C -- NVTE_DELAYED_TENSOR_SCALING --> D["fp8::group_quantize / fp8::group_dequantize"]
    C -- NVTE_MXFP8_1D_SCALING --> E["mxfp8::group_quantize / mxfp8::group_dequantize"]
    D --> F["VectorizedUnaryKernelLauncher\n(offsets, first_dims, last_dims,\nnum_tensors, scale_numel, ...)"]
    F --> G["unary_kernel (GPU)\nfor each element:\n  find_tensor_id (binary search)\n  apply per-tensor scale\n  accumulate block_max[tensor_id]"]
    G --> H["per-tensor amax reduction\nreduce_max loop over num_tensors\natomicMaxFloat(amax[t])"]
    G --> I["scale_inv write\nat offsets[tensor_id] only"]
    style H fill:#f96,stroke:#c00
    style G fill:#f96,stroke:#c00

Comments Outside Diff (1)

  1. transformer_engine/common/cast/fp8/quantize_fp8.cuh, lines 548-551

    P2 Backward group_quantize derives N and offset arrays from input, not grad

    N = product(input->data.shape) and the offsets/first_dims/last_dims pointers all come from input (the forward activation tensor), while the kernel iterates over grad.data.dptr using those bounds. If grad and input have different shapes or padding layouts, which is valid for some non-activation-backward call sites, the kernel will read out of bounds from grad or skip elements. These values should be derived from grad so that they match the data actually being quantized. Note also that input is obtained via convertNVTEGroupedTensor without a null check, so a caller passing a null input would fault immediately here.
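A minimal sketch of the suggested fix, using a simplified metadata struct: GroupedMeta, Bounds, backward_bounds, and the needs_input flag are all hypothetical names, not Transformer Engine types. The element count and offsets come from grad, the tensor whose data the kernel actually reads, and both pointers are checked before use instead of faulting later.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified stand-in for grouped-tensor metadata.
struct GroupedMeta {
  std::vector<std::size_t> shape;     // logical shape of the data buffer
  std::vector<std::int64_t> offsets;  // start element of each sub-tensor
};

struct Bounds {
  std::size_t n;                             // elements the kernel will visit
  const std::vector<std::int64_t>* offsets;  // per-tensor offsets to use
};

inline std::size_t product(const std::vector<std::size_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}

// Bounds always come from grad; input is only validated when the
// activation derivative actually needs it.
inline Bounds backward_bounds(const GroupedMeta* grad, const GroupedMeta* input,
                              bool needs_input) {
  if (grad == nullptr) {
    throw std::invalid_argument("backward group_quantize: grad is null");
  }
  if (needs_input && input == nullptr) {
    throw std::invalid_argument("backward group_quantize: input is null");
  }
  return Bounds{product(grad->shape), &grad->offsets};
}
```

With grad shaped {4, 8} and input shaped {4, 16}, the bounds cover the 32 elements of grad rather than the 64 of input.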

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

}
const int warp_id = threadIdx.x / THREADS_PER_WARP;

float block_max[64] = {0.0f};
Contributor

P1 Fixed-size block_max array overflows when num_tensors > 64

block_max is indexed by tensor_id, which can be as large as num_tensors - 1 (a runtime kernel parameter), yet the array has a fixed size of 64. When a grouped tensor has more than 64 sub-tensors, which is common in large MoE models, any thread whose tensor_id >= 64 writes past the end of the array, corrupting other stack/local variables (including warp_id, loop variables, max, etc.) and producing silently wrong quantized outputs or GPU faults. The same defect exists in unary_grad_kernel at the corresponding location.
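One way to remove the fixed bound, shown as a host-side C++ model rather than CUDA and under stated assumptions: if each thread visits element indices in increasing order (for example in a grid-stride loop), the tensor ids it sees never decrease, so one running max per thread is enough, flushed whenever the tensor id changes. accumulate_amax and tensor_of are hypothetical names; tensor_of[i] stands in for find_tensor_id(i), and std::max on amax[] stands in for atomicMaxFloat.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Sentinel meaning "no tensor seen yet".
constexpr std::size_t kNoTensor = static_cast<std::size_t>(-1);

// Models one thread visiting indices begin, begin + stride, ... with no
// per-tensor local array, so any num_tensors is safe.
void accumulate_amax(const std::vector<float>& x,
                     const std::vector<std::size_t>& tensor_of,
                     std::size_t begin, std::size_t stride,
                     std::vector<float>& amax) {
  std::size_t cur = kNoTensor;
  float running = 0.0f;
  for (std::size_t i = begin; i < x.size(); i += stride) {
    const std::size_t t = tensor_of[i];
    if (t != cur) {
      // Tensor boundary: publish the finished tensor's max, start a new one.
      if (cur != kNoTensor) amax[cur] = std::max(amax[cur], running);
      cur = t;
      running = 0.0f;
    }
    running = std::max(running, std::fabs(x[i]));
  }
  if (cur != kNoTensor) amax[cur] = std::max(amax[cur], running);  // final flush
}
```

The trade-off is more global atomics for threads that straddle many small tensors, in exchange for no fixed-size array and no block-wide reduction loop over num_tensors. Sizing a shared-memory buffer from num_tensors at launch is another option.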

Comment on lines +282 to +290
if (offsets != nullptr || num_tensors > 1) {
for (size_t t = 0; t < num_tensors; ++t) {
float t_max = block_max[t];
t_max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(t_max, warp_id);
if (threadIdx.x == 0 && t_max > 0.0f) {
size_t amax_idx = (amax_numel == num_tensors) ? t : 0;
atomicMaxFloat(&amax[amax_idx], t_max);
}
}
Contributor

P1 Shared-memory race in per-tensor amax loop

reduce_max uses a __shared__ float staging[num_warps] array and a single __syncthreads() that ensures visibility before warp 0 reads staging. However, reduce_max does NOT call __syncthreads() after warp 0's read before returning. Calling it in a loop means that warp 1 (and other non-zero warps) can reach the staging[warpid] = my_warp_max write for iteration t+1 before warp 0 finishes reading staging[1] for iteration t. Without an explicit barrier between iterations, the CUDA memory model does not guarantee ordering, so warp 0 can read a partially updated staging[1] and compute an incorrect per-tensor amax. A __syncthreads() is needed after each call to reduce_max in this loop (and the identical loop in unary_grad_kernel).
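The required barrier placement can be modeled on the host with std::thread standing in for warps and a small reusable barrier standing in for __syncthreads(). This is an illustrative sketch (Barrier and per_tensor_max are hypothetical names), not the reduce_max implementation. The second wait() is the barrier the comment asks for: without it, a fast "warp" could overwrite staging[w] for tensor t+1 before warp 0 has read it for tensor t.

```cpp
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

// Minimal reusable barrier, a host stand-in for __syncthreads().
class Barrier {
 public:
  explicit Barrier(std::size_t n) : n_(n) {}
  void wait() {
    std::unique_lock<std::mutex> lock(m_);
    const std::size_t gen = gen_;
    if (++count_ == n_) {  // last arrival releases everyone
      count_ = 0;
      ++gen_;
      cv_.notify_all();
    } else {
      cv_.wait(lock, [&] { return gen != gen_; });
    }
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  std::size_t n_, count_ = 0, gen_ = 0;
};

// partial[w][t] is "warp" w's local max for tensor t. Each warp publishes
// into staging[w]; warp 0 combines the values, once per tensor.
std::vector<float> per_tensor_max(const std::vector<std::vector<float>>& partial) {
  const std::size_t num_warps = partial.size();
  const std::size_t num_tensors = partial.empty() ? 0 : partial[0].size();
  std::vector<float> staging(num_warps, 0.0f), result(num_tensors, 0.0f);
  Barrier bar(num_warps);
  auto worker = [&](std::size_t w) {
    for (std::size_t t = 0; t < num_tensors; ++t) {
      staging[w] = partial[w][t];
      bar.wait();  // all writes visible before warp 0 reads
      if (w == 0) {
        float m = 0.0f;
        for (float v : staging) m = std::max(m, v);
        result[t] = m;
      }
      bar.wait();  // warp 0 finished reading before staging is reused
    }
  };
  std::vector<std::thread> threads;
  for (std::size_t w = 0; w < num_warps; ++w) threads.emplace_back(worker, w);
  for (auto& th : threads) th.join();
  return result;
}
```

Double-buffering staging across iterations, or reducing all num_tensors values in a single pass, would avoid the extra barrier per tensor; this model only shows where the barrier must go if the loop is kept.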


Labels

community-contribution: PRs from external contributors outside the core maintainers, representing community-driven work.


Development

Successfully merging this pull request may close these issues.

Quantization support for GroupedTensor: FP8 per-tensor

1 participant