[Common] MXFP8 kernel for grouped tensors #2586
Conversation
Greptile Summary
Important Files Changed
Confidence score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant TestSuite as "GroupedFusedCastMXFP8TestSuite"
    participant Helper as "performTest<IType, OType>"
    participant KernelDispatch as "group_quantize_*_helper"
    participant Kernel as "group_quantize_mxfp8_kernel"
    participant GPU as "CUDA Device"
    User->>TestSuite: "Run test with parameters"
    TestSuite->>TestSuite: "Configure tensors and validation"
    TestSuite->>Helper: "performTest(processing_method, OP, ...)"
    Helper->>Helper: "Setup input/output tensors and reference data"
    Helper->>Helper: "Create NVTE grouped tensors"
    alt Processing Method: CAST_ONLY
        Helper->>KernelDispatch: "nvte_group_quantize(input, output)"
    else Processing Method: CAST_DBIAS
        Helper->>KernelDispatch: "nvte_group_quantize_dbias(grad, output, dbias)"
    else Processing Method: CAST_ACT
        Helper->>KernelDispatch: "nvte_group_gelu/silu/relu(input, output)"
    else Processing Method: CAST_DACT
        Helper->>KernelDispatch: "nvte_group_dgelu/dsilu/drelu(grad, input, output)"
    else Processing Method: CAST_DBIAS_DACT
        Helper->>KernelDispatch: "nvte_group_quantize_dbias_dgelu/dsilu/drelu(grad, input, output, dbias)"
    end
    KernelDispatch->>KernelDispatch: "group_quantize_fwd/bwd_helper"
    KernelDispatch->>Kernel: "update_tma_descriptors<<<num_tensors, 32>>>"
    GPU->>KernelDispatch: "TMA descriptors updated"
    KernelDispatch->>Kernel: "group_quantize_mxfp8_kernel<<<blocks, threads>>>"
    GPU->>Kernel: "Execute MXFP8 quantization with scaling"
    Kernel->>GPU: "Write quantized data and scaling factors"
    GPU->>KernelDispatch: "Kernel execution complete"
    alt IS_DBIAS enabled
        KernelDispatch->>GPU: "reduce_dbias(workspace, dbias)"
    end
    KernelDispatch->>Helper: "Quantized tensors ready"
    Helper->>Helper: "Compare GPU results with CPU reference"
    Helper->>TestSuite: "Validation complete"
    TestSuite->>User: "Test result"
```
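For orientation, here is a minimal host-side sketch of the dispatch flow the diagram describes. Only the kernel names (`update_tma_descriptors`, `group_quantize_mxfp8_kernel`, `reduce_dbias`) come from the diagram; `GroupArgs`, every signature, and all launch dimensions are hypothetical placeholders, not the PR's actual code.

```cpp
#include <cuda_runtime.h>

struct GroupArgs {};  // hypothetical: per-tensor TMA descriptors and shapes

// Stub kernels standing in for the real ones named in the diagram.
__global__ void update_tma_descriptors(GroupArgs* /*args*/) {}
__global__ void group_quantize_mxfp8_kernel(const GroupArgs* /*args*/) {}
__global__ void reduce_dbias(const float* /*workspace*/, float* /*dbias*/) {}

template <bool IS_DBIAS>
void group_quantize_helper(GroupArgs* d_args, int num_tensors, dim3 grid,
                           dim3 block, const float* workspace, float* dbias,
                           cudaStream_t stream) {
  // One warp per tensor patches that tensor's TMA descriptors in place,
  // so a single fused launch can address every member of the group.
  update_tma_descriptors<<<num_tensors, 32, 0, stream>>>(d_args);
  // One kernel quantizes the whole group to MXFP8 (data + scaling factors).
  group_quantize_mxfp8_kernel<<<grid, block, 0, stream>>>(d_args);
  if constexpr (IS_DBIAS) {
    // Backward variants with bias reduce per-tile partial sums afterwards.
    reduce_dbias<<<1, 256, 0, stream>>>(workspace, dbias);
  }
}
```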
10 files reviewed, 6 comments
```cpp
case ShapeRepresentation::SAME_BOTH_DIMS:  // rows_num = first_logical_dim / num_tensors; break;
case ShapeRepresentation::VARYING_LAST_DIM:
  rows_num = first_logical_dim;
  break;
```
logic: commented-out code in a case statement creates ambiguity

The commented-out calculation `rows_num = first_logical_dim / num_tensors; break;` suggests `SAME_BOTH_DIMS` should compute rows differently, yet the fallthrough to `VARYING_LAST_DIM` (which sets `rows_num = first_logical_dim`) may not be the intended behavior. If the fallthrough is intentional, remove the comment to make that clear; if the division by `num_tensors` is needed for `SAME_BOTH_DIMS`, uncomment it and add the `break` statement. Both readings are sketched below.
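The two readings, disambiguated. Only the names from the diff (`ShapeRepresentation`, `rows_num`, `first_logical_dim`, `num_tensors`) are real; the enclosing functions are hypothetical stand-ins for illustration.

```cpp
#include <cstddef>

enum class ShapeRepresentation { SAME_BOTH_DIMS, VARYING_LAST_DIM };

// Reading 1: the fallthrough is intentional, so drop the stale comment.
size_t rows_fallthrough(ShapeRepresentation rep, size_t first_logical_dim) {
  size_t rows_num = 0;
  switch (rep) {
    case ShapeRepresentation::SAME_BOTH_DIMS:  // intentional fallthrough
    case ShapeRepresentation::VARYING_LAST_DIM:
      rows_num = first_logical_dim;
      break;
  }
  return rows_num;
}

// Reading 2: SAME_BOTH_DIMS needs its own calculation, so restore it.
size_t rows_split(ShapeRepresentation rep, size_t first_logical_dim,
                  size_t num_tensors) {
  size_t rows_num = 0;
  switch (rep) {
    case ShapeRepresentation::SAME_BOTH_DIMS:
      rows_num = first_logical_dim / num_tensors;
      break;
    case ShapeRepresentation::VARYING_LAST_DIM:
      rows_num = first_logical_dim;
      break;
  }
  return rows_num;
}
```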
10 files reviewed, 1 comment
```cpp
  create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
                       first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
                       last_logical_dim, 0, output_type_bit_size);
} constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
```
style: missing newline after the closing brace before the `constexpr` declarations

This breaks the visual separation between the conditional-compilation block and the subsequent constant declarations. Add a newline for readability.
```diff
- } constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
+ }
+ constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
```
Description
This PR adds a new kernel that supports MXFP8 quantization of grouped tensors.
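As background, here is a minimal CPU sketch of MXFP8 block quantization under the OCP MX convention (32 consecutive elements share one power-of-two E8M0 scale; elements are stored as FP8 E4M3). This mirrors what a CPU reference for such a kernel could look like; it is an illustrative assumption, not the PR's code, and function and constant names are invented for the example.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr size_t kBlock = 32;       // MX scaling-block size
constexpr float kE4M3Max = 448.0f;  // largest finite FP8 E4M3 value

void mxfp8_quantize_row(const std::vector<float>& row,
                        std::vector<uint8_t>& scales,  // one E8M0 byte / block
                        std::vector<float>& scaled) {  // values pre-FP8-cast
  for (size_t start = 0; start < row.size(); start += kBlock) {
    const size_t end = std::min(row.size(), start + kBlock);
    float amax = 0.0f;
    for (size_t i = start; i < end; ++i)
      amax = std::max(amax, std::fabs(row[i]));
    // Pick a power-of-two scale that maps the block max into E4M3 range
    // (ceil avoids overflow; the actual rounding rule may differ).
    const int exp = (amax > 0.0f)
        ? static_cast<int>(std::ceil(std::log2(amax / kE4M3Max)))
        : -127;
    scales.push_back(static_cast<uint8_t>(exp + 127));  // E8M0 biased exponent
    const float scale = std::ldexp(1.0f, exp);
    for (size_t i = start; i < end; ++i)
      scaled.push_back(row[i] / scale);  // a real kernel casts to E4M3 here
  }
}
```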
Fixes # (issue)
Type of change
Changes
Checklist: