[Bug] AdjustMatmulOrder crashes on batched matmul chains (3D @ 2D) #19576

@wuyii8941

Description

The AdjustMatmulOrder pass crashes with Check failed: shape_c.size() == 2 (3 vs. 2) when encountering a matmul chain where the intermediate result is 3D and the final weight is 2D.

This is a common pattern in transformer models: matmul(attn_output[B,S,D], W_o[D,D]) where the attention output is 3D (batched) and the output projection weight is 2D.
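To make the shapes in this pattern concrete, here is a minimal pure-Python sketch of NumPy-style batched matmul shape inference. It assumes that Relax's matmul follows the same broadcasting convention for leading batch dimensions. The helper name `matmul_shape` is hypothetical and is not part of TVM.

```python
from itertools import zip_longest


def matmul_shape(a, b):
    """Result shape of a NumPy-style batched matmul for operand shapes a and b.

    Hypothetical helper for illustration only; covers operands of rank >= 2.
    """
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only covers operands of rank >= 2")
    if a[-1] != b[-2]:
        raise ValueError(f"reduction dims differ: {a[-1]} vs {b[-2]}")
    batch = []
    # Broadcast the leading (batch) dims right-aligned; a missing dim acts as 1.
    for da, db in zip_longest(reversed(a[:-2]), reversed(b[:-2]), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dims not broadcastable: {da} vs {db}")
        batch.append(max(da, db))
    return tuple(reversed(batch)) + (a[-2], b[-1])


B, S, D = 2, 16, 64
out_shape = matmul_shape((B, S, S), (B, S, D))   # attn @ v: 3D @ 3D
proj_shape = matmul_shape(out_shape, (D, D))     # out @ wo: 3D @ 2D
print(out_shape, proj_shape)
```

Under this convention the 2D weight has no batch dimensions, so the batch dimension of the 3D operand is carried through unchanged to the result.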

Reproducer

import numpy as np
import tvm
from tvm import relax
import tvm.relax.op as R

B, S, D = 2, 16, 64
bb = relax.BlockBuilder()
x = relax.Var('x', relax.TensorStructInfo((B, S, D), 'float32'))
wq = relax.Var('wq', relax.TensorStructInfo((D, D), 'float32'))
wk = relax.Var('wk', relax.TensorStructInfo((D, D), 'float32'))
wv = relax.Var('wv', relax.TensorStructInfo((D, D), 'float32'))
wo = relax.Var('wo', relax.TensorStructInfo((D, D), 'float32'))
with bb.function('main', [x, wq, wk, wv, wo]):
    with bb.dataflow():
        q = bb.emit(R.matmul(x, wq))
        k = bb.emit(R.matmul(x, wk))
        v = bb.emit(R.matmul(x, wv))
        kt = bb.emit(R.permute_dims(k, [0, 2, 1]))
        scores = bb.emit(R.matmul(q, kt))
        scale = relax.const(1.0 / np.sqrt(D), 'float32')
        scores = bb.emit(R.multiply(scores, scale))
        attn = bb.emit(R.nn.softmax(scores, axis=-1))
        out = bb.emit(R.matmul(attn, v))         # 3D result
        proj = bb.emit_output(R.matmul(out, wo))  # 3D @ 2D → crash
    bb.emit_func_output(proj)
mod = bb.finalize()

# This crashes:
pipeline = tvm.ir.transform.Sequential([
    relax.transform.AdjustMatmulOrder(),
    relax.transform.LegalizeOps()
])
mod_l = pipeline(mod)  # Check failed: shape_c.size() == 2 (3 vs. 2)

Error

tvm.error.InternalError: Check failed: shape_c.size() == 2 (3 vs. 2) :

Root cause

The AdjustMatmulOrder pass assumes all operands in a matmul chain are 2D (shape.size() == 2). When the intermediate result of matmul(attn, v) produces a 3D tensor [B, S, D] that is then multiplied by a 2D weight [D, D], the pass fails the assertion.

Expected behavior

The pass should either handle mixed-dimension matmul chains (3D @ 2D) or skip them gracefully.
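The "skip gracefully" option could look roughly like the guard below. This is only a sketch of the idea in Python and not the pass's actual C++ implementation; the function name `can_reorder_chain` and its arguments are hypothetical.

```python
def can_reorder_chain(shape_a, shape_b, shape_c):
    """Hypothetical guard: allow reordering only for purely 2D matmul chains.

    Returning False means "leave this chain untouched" instead of
    failing a hard check on the operand ranks.
    """
    return all(len(shape) == 2 for shape in (shape_a, shape_b, shape_c))


B, S, D = 2, 16, 64
# Chain from the reproducer: (attn @ v) @ wo mixes 3D and 2D operands.
print(can_reorder_chain((B, S, S), (B, S, D), (D, D)))
# A purely 2D chain would still be eligible for reordering.
print(can_reorder_chain((S, D), (D, D), (D, 8)))
```

With a guard of this kind, the reproducer's chain would be left as-is and the rest of the pipeline could continue to LegalizeOps.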

Environment

  • TVM version: 0.24.dev0 (commit 0b0afd8, 2026-04-24)

Labels: needs-triage, type: bug
