Description
The AdjustMatmulOrder pass crashes with Check failed: shape_c.size() == 2 (3 vs. 2) when encountering a matmul chain where the intermediate result is 3D and the final weight is 2D.
This is a common pattern in transformer models: matmul(attn_output[B,S,D], W_o[D,D]) where the attention output is 3D (batched) and the output projection weight is 2D.
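Here is why the chain contains a rank-3 intermediate. Relax's `matmul` follows NumPy / array-API style shape rules: the last two axes are multiplied and any leading (batch) axes are broadcast. The pure-Python sketch below assumes that rule and is restricted to operands of rank 2 or more. `matmul_result_shape` is a hypothetical helper, not a TVM API.

```python
def matmul_result_shape(a, b):
    """Result shape of a @ b for rank >= 2 operands, NumPy-style:
    multiply the last two axes, broadcast the leading (batch) axes."""
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only covers operands of rank >= 2")
    m, k1 = a[-2], a[-1]
    k2, n = b[-2], b[-1]
    if k1 != k2:
        raise ValueError(f"reduction dims differ: {k1} vs {k2}")
    # Right-align the batch axes by padding the shorter one with 1s.
    batch_a, batch_b = tuple(a[:-2]), tuple(b[:-2])
    width = max(len(batch_a), len(batch_b))
    batch_a = (1,) * (width - len(batch_a)) + batch_a
    batch_b = (1,) * (width - len(batch_b)) + batch_b
    batch = []
    for da, db in zip(batch_a, batch_b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dims {da} and {db} do not broadcast")
        batch.append(db if da == 1 else da)
    return tuple(batch) + (m, n)


B, S, D = 2, 16, 64
print(matmul_result_shape((B, S, S), (B, S, D)))  # attn @ v  -> (2, 16, 64)
print(matmul_result_shape((B, S, D), (D, D)))     # out @ W_o -> (2, 16, 64)
```

Both products in the tail of the attention block are rank 3, so any rewrite of `matmul(matmul(attn, v), wo)` sees operands with more than two dimensions.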
Reproducer
```python
import numpy as np
import tvm
from tvm import relax
import tvm.relax.op as R

B, S, D = 2, 16, 64

bb = relax.BlockBuilder()
x = relax.Var('x', relax.TensorStructInfo((B, S, D), 'float32'))
wq = relax.Var('wq', relax.TensorStructInfo((D, D), 'float32'))
wk = relax.Var('wk', relax.TensorStructInfo((D, D), 'float32'))
wv = relax.Var('wv', relax.TensorStructInfo((D, D), 'float32'))
wo = relax.Var('wo', relax.TensorStructInfo((D, D), 'float32'))

with bb.function('main', [x, wq, wk, wv, wo]):
    with bb.dataflow():
        q = bb.emit(R.matmul(x, wq))
        k = bb.emit(R.matmul(x, wk))
        v = bb.emit(R.matmul(x, wv))
        kt = bb.emit(R.permute_dims(k, [0, 2, 1]))
        scores = bb.emit(R.matmul(q, kt))
        scale = relax.const(1.0 / np.sqrt(D), 'float32')
        scores = bb.emit(R.multiply(scores, scale))
        attn = bb.emit(R.nn.softmax(scores, axis=-1))
        out = bb.emit(R.matmul(attn, v))  # 3D result
        proj = bb.emit_output(R.matmul(out, wo))  # 3D @ 2D → crash
    bb.emit_func_output(proj)
mod = bb.finalize()

# This crashes:
pipeline = tvm.ir.transform.Sequential([
    relax.transform.AdjustMatmulOrder(),
    relax.transform.LegalizeOps(),
])
mod_l = pipeline(mod)  # Check failed: shape_c.size() == 2 (3 vs. 2)
```
Error
tvm.error.InternalError: Check failed: shape_c.size() == 2 (3 vs. 2) :
Root cause
The AdjustMatmulOrder pass assumes that every operand in a matmul chain is 2D (shape.size() == 2). Here, matmul(attn, v) produces a 3D intermediate of shape [B, S, D]. That intermediate is then multiplied by the 2D weight [D, D], so the rank assertion fails.
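A 3D @ 2D product does not have to fall outside a 2D view of the chain. The same [D, D] weight is applied to every batch, so [B, S, D] @ [D, D] equals a single 2D matmul over B*S rows. The pure-Python sketch below uses small integer inputs and no TVM, and its helper names are hypothetical. It checks that equivalence by folding the leading dims into rows, multiplying, and unfolding.

```python
import random


def matmul2d(a, b):
    """Plain 2D matmul on nested lists: (m x k) @ (k x n) -> (m x n)."""
    k, n = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(k)) for j in range(n)] for row in a]


def matmul_3d_2d(x, w):
    """[B, S, K] @ [K, N], computed one batch at a time."""
    return [matmul2d(x_b, w) for x_b in x]


def matmul_3d_2d_flattened(x, w):
    """Same product: fold [B, S, K] into [B*S, K], do one 2D matmul,
    then unfold the [B*S, N] result back to [B, S, N]."""
    batch, seq = len(x), len(x[0])
    flat = [row for x_b in x for row in x_b]                    # [B*S, K]
    out = matmul2d(flat, w)                                     # [B*S, N]
    return [out[i * seq:(i + 1) * seq] for i in range(batch)]   # [B, S, N]


B, S, D = 2, 3, 4  # small sizes for illustration
rng = random.Random(0)
x = [[[rng.randint(-3, 3) for _ in range(D)] for _ in range(S)] for _ in range(B)]
w = [[rng.randint(-3, 3) for _ in range(D)] for _ in range(D)]
assert matmul_3d_2d(x, w) == matmul_3d_2d_flattened(x, w)
```

Integer inputs keep the comparison exact. With floats, the two orders of summation could differ in the last bits.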
Expected behavior
The pass should either handle mixed-dimension matmul chains (3D @ 2D) or skip them gracefully.
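The two options can be sketched together as one decision function. The name `choose_order`, its string return values, and the multiply-accumulate cost model are illustrative assumptions. They are not TVM's actual C++ implementation of AdjustMatmulOrder. The sketch reorders (A @ B) @ C versus A @ (B @ C) by cost. When B and C are 2D, it folds A's leading dims into its row dim. Any other rank combination returns "skip" rather than asserting.

```python
from math import prod


def choose_order(shape_a, shape_b, shape_c):
    """Hypothetical reorder decision for A[m, k] @ B[k, n] @ C[n, p].

    Assumes the shapes are already compatible. Returns "left" for
    (A @ B) @ C, "right" for A @ (B @ C), or "skip" when the chain
    does not fit the 2D cost model.
    """
    # "Handle" option: [..., m, k] @ [k, n] @ [n, p] is a 2D chain over
    # prod(...) * m rows when only A carries leading (batch) dims.
    if len(shape_a) > 2 and len(shape_b) == 2 and len(shape_c) == 2:
        shape_a = (prod(shape_a[:-1]), shape_a[-1])
    # "Skip" option: leave any other mixed-rank chain unchanged.
    if not (len(shape_a) == len(shape_b) == len(shape_c) == 2):
        return "skip"
    (m, k), (_, n), (_, p) = shape_a, shape_b, shape_c
    left_cost = m * k * n + m * n * p   # (A @ B) is [m, n], then @ C
    right_cost = k * n * p + m * k * p  # (B @ C) is [k, p], then A @
    return "right" if right_cost < left_cost else "left"


B, S, D = 2, 16, 64
print(choose_order((B, S, S), (B, S, D), (D, D)))  # B is 3D -> "skip"
print(choose_order((B, S, D), (D, D), (D, D)))     # folded to (32, 64) -> "left"
```

Skipping is the conservative choice because it only loses an optimization. Folding recovers the optimization for the common activation @ weight @ weight case.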
Environment
- TVM version: 0.24.dev0 (commit 0b0afd8, 2026-04-24)