[Relax][ONNX] Fix Cast operator float->int NaN/Inf handling#19626
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the ONNX frontend in TVM Relax to handle casting from floating-point types to integer types by replacing NaN and Inf values with 0.0 before performing the cast. The reviewer suggested simplifying the check for non-finite values by using relax.op.logical_not(relax.op.isfinite(src)) instead of combining isnan and isinf with a logical OR, which reduces the number of operator calls and simplifies the generated Relax graph.
tlopex
left a comment
There was a problem hiding this comment.
Thanks for working on this. I don’t think this is ready to merge yet.
The new dynamic Cast path now applies to all integer destination dtypes, but int64/uint64 look broken. For bits == 64, the code builds constants like (1 << 64) - 1 and 1 << 63 with temp_dtype = "int64", which are not representable as int64 Relax constants. This means importing a non-constant ONNX Cast from float to INT64/UINT64 can fail in the frontend before codegen.
Could you please add coverage for FLOAT -> INT64 and FLOAT -> UINT64, and either handle the 64-bit case separately or avoid the wrapping path for those dtypes?
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces casting logic from float to integer types in the ONNX frontend, including handling of NaN/Inf values and proper integer wrapping behavior for different bit widths, along with corresponding unit tests. The review feedback suggests simplifying and unifying the integer wrapping logic for bit widths less than 64 by using bitwise_and with a mask, which avoids special-casing 32-bit integers and eliminates the need for floor_mod and runtime addition.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if bits == 32: | ||
| two_pow = relax.const(1 << bits, temp_dtype) | ||
| uw = relax.op.floor_mod(t, two_pow) | ||
| else: | ||
| mask_val = (1 << bits) - 1 | ||
| mask = relax.const(mask_val, temp_dtype) | ||
| uw = relax.op.bitwise_and(t, mask) | ||
| if signed: | ||
| half = 1 << (bits - 1) | ||
| half_c = relax.const(half, temp_dtype) | ||
| if bits == 32: | ||
| two_pow = relax.const(1 << bits, temp_dtype) | ||
| else: | ||
| two_pow = relax.op.add(mask, relax.const(1, temp_dtype)) | ||
| wrapped = relax.op.where( | ||
| relax.op.greater_equal(uw, half_c), | ||
| relax.op.subtract(uw, two_pow), | ||
| uw, | ||
| ) | ||
| else: | ||
| wrapped = uw | ||
| return relax.op.astype(wrapped, to_type) |
There was a problem hiding this comment.
The logic for wrapping integer values can be significantly simplified and unified.
By using bitwise_and with a mask of (1 << bits) - 1 for all bit widths less than 64 (including 32-bit), we can completely eliminate the special casing for bits == 32, avoid the expensive floor_mod operator, and remove the runtime addition relax.op.add(mask, 1).
Since temp_dtype is strictly larger than bits (i.e., int64 for 32-bit, and int32 for 8/16-bit), bitwise_and with the mask correctly handles both positive and negative values in two's complement representation.
mask_val = (1 << bits) - 1
mask = relax.const(mask_val, temp_dtype)
uw = relax.op.bitwise_and(t, mask)
if signed:
half = 1 << (bits - 1)
half_c = relax.const(half, temp_dtype)
two_pow = relax.const(1 << bits, temp_dtype)
wrapped = relax.op.where(
relax.op.greater_equal(uw, half_c),
relax.op.subtract(uw, two_pow),
uw,
)
else:
wrapped = uw
return relax.op.astype(wrapped, to_type)|
Hi @tlopex Thank you for the reminder. I have updated the section you mentioned. 😄 |
|
Thanks to @tlopex 😄 |
Hi Committers,
This PR is trying to fix issues #19542. Any suggestions would be appreciated if you are available.
Root cause:
FP to INT lowering can be implementation-defined or UB for NaN/Inf and extreme floats, producing backend-dependent results versus ONNX Runtime.
Solution:
Apply a minimal, deterministic frontend sanitization for float to integer Casts: map NaN and ±Inf to 0.0 before astype. This prevents NaN/Inf from reaching backend fptosi/fptoui lowers and yields stable behavior across targets.