[Fix] Stabilize layer_norm variance computation with two-pass reduction#19643
Conversation
This PR will fix apache#19592. LayerNorm could produce NaN on large-value, small-variance inputs due to catastrophic cancellation in var = E[x^2] - E[x]^2. Switch to a numerically stable two-pass formulation: - pass1 computes mean via sum(x) / N - pass2 computes variance via sum((x - mean)^2) / N
There was a problem hiding this comment.
Code Review
This pull request updates the layer normalization implementation in TVM TOPI to use a two-pass algorithm instead of a single-pass algorithm, which improves numerical stability by first computing the mean and then computing the variance based on the mean. The corresponding unit tests in Relax have been updated to reflect the new TIR structure. There are no review comments to evaluate, so no additional feedback is provided.
tlopex
left a comment
There was a problem hiding this comment.
Thanks, the updated legalization tests are useful and they do verify that LegalizeOps now emits the two-pass TIR structure.
My concern is that this PR fixes a numerical runtime bug, but the tests only check the generated IR structurally. They do not run the #19592 repro input, so they would not directly catch the original NaN failure.
Could you add a small numeric regression test using the issue input, e.g. [[80000.0, 80001.0, 80002.0, 80003.0]] with axis=-1, and check that the TVM output is finite and matches a stable reference / ONNXRuntime output? That would cover the actual failure mode in addition to the structural TIR change.
This PR will fix #19592.
LayerNorm could produce NaN on large-value, small-variance inputs due to catastrophic cancellation in var = E[x^2] - E[x]^2.
Switch to a numerically stable two-pass formulation: