From 9b42f68fcf886426150eea051e707e4ad8c5d6ba Mon Sep 17 00:00:00 2001 From: hualxie Date: Fri, 5 Jun 2026 14:58:30 +0800 Subject: [PATCH 1/7] fix(optim): untie batched constant MatMul for OpenVINO GPU OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched (rank >= 3) MatMul where an operand is a compile-time constant; the same gemm with a dynamic operand, and 2D constant gemm, both compile fine. Transformer disentangled-attention position terms (e.g. DeBERTa) fold to 3D constants and fail to compile with: [GPU] Failed to select implementation for ... type: gemm (compile_graph.cpp:59 selected_impl == nullptr) Add an EP-gated `untie-constant-batched-matmul` surgery that routes the constant operand through Add(const, zero), where zero is a data-dependent runtime [1] tensor (Cast -> Reshape(-1) -> Slice[0:1] -> Sub). This makes the operand runtime-valued so OV's constant folder cannot repack it into a gemm weight, while keeping the single batched MatMul (no perf regression) and leaving numerics unchanged (+0). Wired via autoconf: BatchedConstMatMulValidator detects the pattern and, gated to Intel IHV + GPU, emits a GraphOptimization opportunity the existing autoconf loop auto-applies. Pattern-based, architecture-agnostic. Also makes the model-validator device filter case-insensitive so builds that pass lowercase "gpu" are matched. --- .../analyze/core/information_engine.py | 1 + .../analyze/core/model_validators/__init__.py | 2 + .../batched_const_matmul_validator.py | 129 +++++++++++++++ .../model_validator_manager.py | 24 ++- .../modelkit/optim/capabilities/surgery.py | 17 ++ src/winml/modelkit/optim/pipes/surgery.py | 148 +++++++++++++++++- .../core/model_validators/test_validators.py | 60 +++++++ tests/unit/optim/pipes/test_pipe_surgery.py | 117 ++++++++++++++ 8 files changed, 492 insertions(+), 6 deletions(-) create mode 100644 src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py diff --git a/src/winml/modelkit/analyze/core/information_engine.py b/src/winml/modelkit/analyze/core/information_engine.py index d4c574b16..a3cfe1e1f 100644 --- a/src/winml/modelkit/analyze/core/information_engine.py +++ b/src/winml/modelkit/analyze/core/information_engine.py @@ -297,6 +297,7 @@ def _check_model(self) -> list[Information]: self._model, op_runtime_results=self._op_runtime_results, device=self._device, + ep=self._ep, ) manager_init_ms = int((time.perf_counter() - manager_init_start) * 1000) diff --git a/src/winml/modelkit/analyze/core/model_validators/__init__.py b/src/winml/modelkit/analyze/core/model_validators/__init__.py index 77cd7b924..4504269da 100644 --- a/src/winml/modelkit/analyze/core/model_validators/__init__.py +++ b/src/winml/modelkit/analyze/core/model_validators/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .base import ModelValidator +from .batched_const_matmul_validator import BatchedConstMatMulValidator from .constant_folding_validator import ConstantFoldingValidator from .dynamic_input_validator import DynamicInputValidator from .model_validator_manager import ModelValidatorManager @@ -21,6 +22,7 @@ __all__ = [ + "BatchedConstMatMulValidator", "ConstantFoldingValidator", "DynamicInputValidator", "ModelValidator", diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py new file mode 100644 index 000000000..9c9200546 --- /dev/null +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -0,0 +1,129 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Validator for batched MatMul with a constant operand on OpenVINO GPU. + +OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched +(rank >= 3) MatMul where an operand is a compile-time constant. The identical +gemm with a dynamic operand, and 2D constant gemm, both compile fine. Models +whose batched MatMul weights fold to constants (e.g. transformer disentangled +attention position terms) therefore fail to compile on OpenVINO GPU with: + + [GPU] Failed to select implementation for ... type: gemm + +This validator detects that structural pattern and recommends the +``untie-constant-batched-matmul`` surgery, which makes the constant operand +runtime-valued so gemm implementation selection succeeds. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ...models.information import Action, ActionItem, ActionLevel, Information +from ...utils import infer_ihv_from_ep_name +from .base import ModelValidator + + +if TYPE_CHECKING: + from ...models.onnx_model import ONNXModel + from ...models.runtime_checks import PatternRuntime + +logger = logging.getLogger(__name__) + +# Surgery capability enabled when the pattern is detected (kebab-case to match +# the capability registry / autoconf normalization). +_SURGERY_FLAG = "untie-constant-batched-matmul" + + +class BatchedConstMatMulValidator(ModelValidator): + """Detect batched MatMul with a constant operand (OpenVINO GPU only).""" + + def __init__( + self, + model: ONNXModel, + op_runtime_results: list[PatternRuntime] | None = None, + ep: str | None = None, + device: str | None = None, + ) -> None: + super().__init__(model, op_runtime_results=op_runtime_results) + self.ep = ep + self.device = device + + @property + def validator_name(self) -> str: + """Name of this validator for logging/reporting.""" + return "BatchedConstMatMulValidator" + + @property + def pattern_id(self) -> str: + """Pattern ID for Information objects.""" + return "MODEL/BatchedConstantMatMul" + + def _is_enabled(self) -> bool: + """Only relevant for OpenVINO (Intel IHV) on GPU.""" + if (self.device or "").upper() != "GPU": + return False + if not self.ep: + return False + try: + from ...models.ihv_type import IHVType + + return infer_ihv_from_ep_name(self.ep) == IHVType.INTEL + except Exception: # pragma: no cover - defensive + return False + + def validate(self) -> Information | None: + """Detect batched MatMul with a single constant rank>=3 operand.""" + if not self._is_enabled(): + return None + + initializers = {init.name for init in self.graph.initializer} + rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer} + + offenders: list[str] = [] + for node in self.graph.node: + if node.op_type != "MatMul" or len(node.input) != 2: + continue + const_inputs = [name for name in node.input if name in initializers] + # Exactly one constant operand (two-constant MatMuls fold away and + # never reach gemm impl selection). + if len(const_inputs) != 1: + continue + if rank_by_init.get(const_inputs[0], 0) >= 3: + offenders.append(node.name or const_inputs[0]) + + if not offenders: + return None + + examples = ", ".join(offenders[:3]) + action = Action( + pattern_from_id="", + pattern_to_id="", + level=ActionLevel.REQUIRED, + status=None, + action_items=[ + ActionItem(type="GraphOptimization", optimization_options={_SURGERY_FLAG: True}) + ], + details=( + "Enable untie-constant-batched-matmul surgery so the constant " + "operand becomes runtime-valued and OpenVINO GPU can select a " + "gemm implementation." + ), + ) + explanation = ( + f"Model contains {len(offenders)} batched MatMul(s) with a constant " + f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot " + f"select an implementation for a batched MatMul with a constant " + f"operand, causing a '[GPU] Failed to select implementation ... gemm' " + f"compile failure. The untie-constant-batched-matmul surgery makes " + f"the operand runtime-valued without changing numerics." + ) + return Information( + explanation=explanation, + actions=[action], + pattern_id=self.pattern_id, + status=None, + ) diff --git a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py index 29dd0235c..d658dbe53 100644 --- a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py +++ b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, ClassVar from ...utils.timing_utils import make_timing_logger +from .batched_const_matmul_validator import BatchedConstMatMulValidator from .constant_folding_validator import ConstantFoldingValidator from .dynamic_input_validator import DynamicInputValidator from .pattern_matching_validator import PatternMatchingValidator @@ -64,6 +65,11 @@ class ModelValidatorManager: "class": PatternMatchingValidator, "enabled_devices": None, # All devices }, + "batched_const_matmul": { + "class": BatchedConstMatMulValidator, + "enabled_devices": ["GPU"], # OpenVINO GPU gemm impl-selection issue + "needs_context": True, # validator self-gates on EP (Intel IHV) + }, } def __init__( @@ -72,6 +78,7 @@ def __init__( enabled_validators: list[str] | None = None, op_runtime_results: list[PatternRuntime] | None = None, device: str | None = None, + ep: str | None = None, ) -> None: """Initialize validator manager. @@ -92,6 +99,7 @@ def __init__( self.model_proto = model.get_model() self.op_runtime_results = op_runtime_results or [] self.device = device or "NPU" + self.ep = ep self.enabled_validators = enabled_validators or list(self.VALIDATORS.keys()) # Instantiate enabled validators @@ -102,18 +110,24 @@ def __init__( validator_class = validator_config["class"] enabled_devices = validator_config.get("enabled_devices") - # Check device constraint - if enabled_devices is not None and self.device not in enabled_devices: + # Check device constraint (case-insensitive: callers may pass + # "gpu" or "GPU" depending on the build/analyze entry point). + if enabled_devices is not None and (self.device or "").upper() not in { + d.upper() for d in enabled_devices + }: logger.info( f"Validator '{name}' is not enabled for device '{self.device}'. " f"Only enabled for: {enabled_devices}" ) continue + ctor_kwargs: dict = {"op_runtime_results": self.op_runtime_results} + if validator_config.get("needs_context"): + ctor_kwargs["ep"] = self.ep + ctor_kwargs["device"] = self.device + try: - self.validators.append( - validator_class(self.model, op_runtime_results=self.op_runtime_results) - ) + self.validators.append(validator_class(self.model, **ctor_kwargs)) logger.debug(f"Initialized validator: {name}") except Exception: logger.exception(f"Failed to initialize validator {name}") diff --git a/src/winml/modelkit/optim/capabilities/surgery.py b/src/winml/modelkit/optim/capabilities/surgery.py index 8b2048f00..0b6ec0768 100644 --- a/src/winml/modelkit/optim/capabilities/surgery.py +++ b/src/winml/modelkit/optim/capabilities/surgery.py @@ -37,3 +37,20 @@ category=CapabilityCategory.SURGERY, default=False, ) + +# Route a constant operand of a batched (rank >= 3) MatMul through a runtime +# no-op so it is no longer a compile-time constant. OpenVINO GPU's oneDNN gemm +# cannot select an implementation for a batched MatMul with a constant operand +# (e.g. transformer disentangled-attention position terms that fold to 3D +# constants); making the operand runtime-valued lets gemm impl selection +# succeed without changing numerics or splitting the batched op. +UNTIE_CONSTANT_BATCHED_MATMUL = BoolCapability( + name="untie-constant-batched-matmul", + ort_name=None, # Custom implementation, not ORT optimizer + description=( + "Make a batched MatMul's constant operand runtime-valued so OpenVINO " + "GPU can select a gemm implementation" + ), + category=CapabilityCategory.SURGERY, + default=False, +) diff --git a/src/winml/modelkit/optim/pipes/surgery.py b/src/winml/modelkit/optim/pipes/surgery.py index fa4fa6bcf..fed9b6ce6 100644 --- a/src/winml/modelkit/optim/pipes/surgery.py +++ b/src/winml/modelkit/optim/pipes/surgery.py @@ -38,6 +38,7 @@ SURGERY_CAPABILITIES: dict[str, Any] = caps_dict( surgery.CLAMP_CONSTANT_VALUES, surgery.REMOVE_ISNAN_IN_ATTENTION_MASK, + surgery.UNTIE_CONSTANT_BATCHED_MATMUL, ) @@ -57,6 +58,8 @@ class SurgeryPipeConfig(PipeConfig): fix_nan_attention_mask: Replace -inf attention mask with finite value and remove Softmax->IsNaN->Where NaN guard patterns mask_value: Replacement value for -inf (default: -1e3) + untie_constant_batched_matmul: Make a batched MatMul's constant operand + runtime-valued so OpenVINO GPU can select a gemm implementation verbose: Enable verbose logging """ @@ -64,6 +67,7 @@ class SurgeryPipeConfig(PipeConfig): clamp_min: float = -1e3 clamp_max: float = 1e3 remove_isnan_in_attention_mask: bool = False + untie_constant_batched_matmul: bool = False verbose: bool = False @@ -106,6 +110,7 @@ def build_config(cls, **kwargs: Any) -> SurgeryPipeConfig: clamp_min=kwargs.get("clamp_min", -1e3), clamp_max=kwargs.get("clamp_max", 1e3), remove_isnan_in_attention_mask=kwargs.get("remove_isnan_in_attention_mask", False), + untie_constant_batched_matmul=kwargs.get("untie_constant_batched_matmul", False), verbose=kwargs.get("verbose", False), ) @@ -119,7 +124,11 @@ def should_process(cls, config: SurgeryPipeConfig) -> bool: Returns: True if any surgery operation is enabled """ - return config.clamp_constant_values or config.remove_isnan_in_attention_mask + return ( + config.clamp_constant_values + or config.remove_isnan_in_attention_mask + or config.untie_constant_batched_matmul + ) def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.ModelProto: """Apply surgery operations to the model. @@ -149,6 +158,9 @@ def process(self, model: onnx.ModelProto, config: SurgeryPipeConfig) -> onnx.Mod if config.remove_isnan_in_attention_mask: model_copy = self._remove_isnan_in_attention_mask(model_copy, config.verbose) + if config.untie_constant_batched_matmul: + model_copy = self._untie_constant_batched_matmul(model_copy, config.verbose) + return model_copy def _clamp_constant_values( @@ -319,3 +331,137 @@ def _remove_isnan_in_attention_mask( ) return model + + # ----------------------------------------------------------------- + # untie-constant-batched-matmul + # ----------------------------------------------------------------- + + def _untie_constant_batched_matmul( + self, + model: onnx.ModelProto, + verbose: bool = False, + ) -> onnx.ModelProto: + """Make a batched MatMul's constant operand runtime-valued. + + OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched + (rank >= 3) MatMul where an operand is a compile-time constant: the same + gemm with a dynamic operand, and 2D constant gemm, both compile fine. + Transformer disentangled-attention position terms depend only on weights, + so they fold into 3D constants and hit this case. + + Fix: route each such constant operand through ``Add(const, zero)`` where + ``zero = Sub(s, s)`` and ``s = ReduceMin(Cast(first_input -> float))``. + ``s`` is data-dependent, so OpenVINO's constant folder cannot collapse + the Add back into a packed gemm weight, yet ``+ 0`` leaves the values + unchanged and the single batched MatMul is preserved (no perf cost). + """ + from onnx import TensorProto, helper, numpy_helper + + graph = model.graph + initializers = {init.name: init for init in graph.initializer} + + # Collect (matmul_node, operand_index) where the operand is a constant + # initializer of rank >= 3. Skip MatMuls whose operands are all constant + # (those fold away entirely and never reach gemm impl selection). + targets: list[tuple[onnx.NodeProto, int]] = [] + for node in graph.node: + if node.op_type != "MatMul" or len(node.input) != 2: + continue + const_idx = [i for i, name in enumerate(node.input) if name in initializers] + if len(const_idx) != 1: + continue + idx = const_idx[0] + if len(initializers[node.input[idx]].dims) >= 3: + targets.append((node, idx)) + + if not targets: + return model + + if not graph.input: + logger.warning( + "SurgeryPipe: untie-constant-batched-matmul: no graph input to " + "derive a runtime value from; skipping %d MatMul(s)", + len(targets), + ) + return model + + prefix = "winml_ovgpu_untie" + first_input = graph.input[0].name + new_nodes: list[onnx.NodeProto] = [] + new_inits: list[onnx.TensorProto] = [] + + # Build a shape-[1] runtime zero from input *data* (not shape — input + # shapes are static and would be folded). Only ubiquitous ops are used + # so the static analyzer handles them: a single input element is sliced + # out and subtracted from itself. A [1] tensor broadcasts against any + # constant operand, regardless of its rank. + xf = f"{prefix}_xf" + new_nodes.append( + helper.make_node("Cast", [first_input], [xf], to=TensorProto.FLOAT, name=xf) + ) + flat = f"{prefix}_flat" + new_inits.append(numpy_helper.from_array(np.array([-1], dtype=np.int64), f"{prefix}_m1")) + new_nodes.append(helper.make_node("Reshape", [xf, f"{prefix}_m1"], [flat], name=flat)) + elem = f"{prefix}_elem" + new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}_0")) + new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), f"{prefix}_1")) + new_nodes.append( + helper.make_node( + "Slice", [flat, f"{prefix}_0", f"{prefix}_1", f"{prefix}_0"], [elem], name=elem + ) + ) + # zero = elem - elem == 0.0 (data-dependent, so it is not folded away). + zero_f32 = f"{prefix}_zero_f32" + new_nodes.append(helper.make_node("Sub", [elem, elem], [zero_f32], name=zero_f32)) + + # A zero must match each operand's dtype (ONNX has no implicit promotion). + zero_by_dtype: dict[int, str] = {int(TensorProto.FLOAT): zero_f32} + + def zero_for(dtype: int) -> str: + name = zero_by_dtype.get(dtype) + if name is None: + name = f"{prefix}_zero_{dtype}" + new_nodes.append( + helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name) + ) + zero_by_dtype[dtype] = name + return name + + untied = 0 + for node, idx in targets: + const_name = node.input[idx] + dtype = initializers[const_name].data_type + if dtype not in (TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE): + continue + dyn = f"{prefix}_{node.name}_in{idx}".replace("/", "_") + new_nodes.append( + helper.make_node("Add", [const_name, zero_for(dtype)], [dyn], name=dyn) + ) + node.input[idx] = dyn + untied += 1 + if verbose: + logger.info( + " untie-constant-batched-matmul: %s input[%d] %s -> %s", + node.name, + idx, + const_name, + dyn, + ) + + if untied == 0: + return model + + graph.initializer.extend(new_inits) + # Prepend new nodes: their inputs are only graph inputs / initializers, + # so placing them first keeps the graph topologically sorted. + existing = list(graph.node) + del graph.node[:] + graph.node.extend(new_nodes + existing) + + logger.info( + "SurgeryPipe: untie-constant-batched-matmul: untied %d batched " + "MatMul constant operand(s)", + untied, + ) + + return model diff --git a/tests/unit/analyze/core/model_validators/test_validators.py b/tests/unit/analyze/core/model_validators/test_validators.py index ddc83c7be..787224aa6 100644 --- a/tests/unit/analyze/core/model_validators/test_validators.py +++ b/tests/unit/analyze/core/model_validators/test_validators.py @@ -379,3 +379,63 @@ def test_unknown_validator_logs_warning(self, caplog): assert len(manager.validators) == 1 assert manager.validators[0].validator_name == "ConstantFoldingValidator" assert "Unknown validator" in caplog.text + + +def _make_batched_const_matmul_proto(const_rank: int = 3): + """Model: data [2,3,4] @ W(const) [2,4,5] -> out [2,3,5].""" + import numpy as np + from onnx import numpy_helper + + w_shape = [2, 4, 5] if const_rank == 3 else [4, 5] + w = numpy_helper.from_array(np.zeros(w_shape, dtype=np.float32), "W") + matmul = helper.make_node("MatMul", ["data", "W"], ["out"], name="batched_matmul") + graph = helper.make_graph( + [matmul], + "batched_const_matmul", + [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 5])], + initializer=[w], + ) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +class TestBatchedConstMatMulValidator: + """OpenVINO-GPU batched constant MatMul detector.""" + + def _validate(self, proto, ep, device): + from winml.modelkit.analyze.core.model_validators import BatchedConstMatMulValidator + + model = create_onnx_model_wrapper(proto) + return BatchedConstMatMulValidator(model, ep=ep, device=device).validate() + + def test_detects_for_openvino_gpu(self): + """Emits a GraphOptimization action enabling the surgery for OV GPU.""" + info = self._validate(_make_batched_const_matmul_proto(), "openvino", "GPU") + assert info is not None + assert info.pattern_id == "MODEL/BatchedConstantMatMul" + items = info.actions[0].action_items + assert items[0].type == "GraphOptimization" + assert items[0].optimization_options == {"untie-constant-batched-matmul": True} + + def test_skipped_for_openvino_npu(self): + """Device-gated: NPU is unaffected.""" + assert self._validate(_make_batched_const_matmul_proto(), "openvino", "NPU") is None + + def test_skipped_for_non_intel_gpu(self): + """IHV-gated: a non-Intel GPU EP is unaffected.""" + info = self._validate(_make_batched_const_matmul_proto(), "DmlExecutionProvider", "GPU") + assert info is None + + def test_skipped_for_two_dim_constant(self): + """Rank-2 constant gemm compiles on OV GPU; not flagged.""" + info = self._validate(_make_batched_const_matmul_proto(const_rank=2), "openvino", "GPU") + assert info is None + + def test_manager_wires_validator_for_openvino_gpu(self): + """Manager constructs the validator and surfaces the action for OV GPU.""" + model = create_onnx_model_wrapper(_make_batched_const_matmul_proto()) + manager = ModelValidatorManager(model, device="GPU", ep="openvino") + names = [v.validator_name for v in manager.validators] + assert "BatchedConstMatMulValidator" in names + infos = manager.run_all_validators() + assert any(i.pattern_id == "MODEL/BatchedConstantMatMul" for i in infos) diff --git a/tests/unit/optim/pipes/test_pipe_surgery.py b/tests/unit/optim/pipes/test_pipe_surgery.py index 6ab7e0034..fb98763d3 100644 --- a/tests/unit/optim/pipes/test_pipe_surgery.py +++ b/tests/unit/optim/pipes/test_pipe_surgery.py @@ -342,3 +342,120 @@ def test_surgery_pipe_runs_last(self) -> None: # SurgeryPipe should be last in the list assert PIPES[-1].name == "surgery" + + +# ============================================================================= +# UNTIE-CONSTANT-BATCHED-MATMUL TESTS +# ============================================================================= + + +def _make_batched_const_matmul_model( + *, + const_rank: int = 3, + const_on_rhs: bool = True, +) -> onnx.ModelProto: + """Build a model with a batched MatMul that has one constant operand. + + data [2,3,4] @ W(const) [2,4,5] -> out [2,3,5] (const on rhs), or the + transposed arrangement when ``const_on_rhs`` is False. + """ + from onnx import TensorProto, helper + + rng = np.random.RandomState(0) + if const_on_rhs: + data_shape, w_shape, out_shape = [2, 3, 4], [2, 4, 5], [2, 3, 5] + mm_inputs = ["data", "W"] + else: + data_shape, w_shape, out_shape = [2, 4, 5], [2, 3, 4], [2, 3, 5] + mm_inputs = ["W", "data"] + + if const_rank == 2: + w_shape = w_shape[1:] + + w = numpy_helper.from_array(rng.randn(*w_shape).astype(np.float32), "W") + matmul = helper.make_node("MatMul", mm_inputs, ["out"], name="batched_matmul") + graph = helper.make_graph( + [matmul], + "test_batched_const_matmul", + [helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape)], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)], + initializer=[w], + ) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +class TestUntieConstantBatchedMatmulCapability: + """Capability/config plumbing for untie-constant-batched-matmul.""" + + def test_capability_exists(self) -> None: + """Capability is registered with a None ort_name (custom impl).""" + assert "untie-constant-batched-matmul" in SURGERY_CAPABILITIES + assert SURGERY_CAPABILITIES["untie-constant-batched-matmul"].ort_name is None + + def test_build_config_enable_via_kwarg(self) -> None: + """Flag can be toggled through build_config.""" + config = SurgeryPipe.build_config(untie_constant_batched_matmul=True) + assert config.untie_constant_batched_matmul is True + + def test_should_process_true_when_enabled(self) -> None: + """should_process is True when only this surgery is enabled.""" + config = SurgeryPipeConfig(untie_constant_batched_matmul=True) + assert SurgeryPipe.should_process(config) is True + + +class TestUntieConstantBatchedMatmulProcess: + """Graph transform behavior.""" + + def test_constant_operand_becomes_runtime_valued(self) -> None: + """The MatMul no longer consumes the initializer directly.""" + model = _make_batched_const_matmul_model() + result = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + + matmul = next(n for n in result.graph.node if n.op_type == "MatMul") + initializer_names = {init.name for init in result.graph.initializer} + # No MatMul input is a direct initializer anymore. + assert not (set(matmul.input) & initializer_names) + # An Add node now produces the (formerly constant) operand. + add_nodes = [n for n in result.graph.node if n.op_type == "Add"] + assert len(add_nodes) == 1 + assert add_nodes[0].output[0] in matmul.input + # Graph remains structurally valid. + onnx.checker.check_model(result) + + def test_numerics_unchanged(self) -> None: + """+0 tie leaves outputs bit-for-bit identical on ORT CPU.""" + import onnxruntime as ort + + model = _make_batched_const_matmul_model() + transformed = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + + rng = np.random.RandomState(7) + feed = {"data": rng.randn(2, 3, 4).astype(np.float32)} + + ref = ort.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + got = ort.InferenceSession( + transformed.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + np.testing.assert_array_equal(ref, got) + + def test_two_dim_constant_is_left_untouched(self) -> None: + """Rank-2 constant gemm compiles on OV GPU, so it must not be rewritten.""" + model = _make_batched_const_matmul_model(const_rank=2) + result = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + assert not any(n.op_type == "Add" for n in result.graph.node) + + def test_constant_on_lhs_is_handled(self) -> None: + """A constant rank-3 operand on the LHS is untied too.""" + model = _make_batched_const_matmul_model(const_on_rhs=False) + result = SurgeryPipe().process( + model, SurgeryPipeConfig(untie_constant_batched_matmul=True) + ) + assert any(n.op_type == "Add" for n in result.graph.node) From 5d04ce0f2e0c0c57a508e8e72a606f8245c49ec8 Mon Sep 17 00:00:00 2001 From: hualxie Date: Mon, 8 Jun 2026 16:46:03 +0800 Subject: [PATCH 2/7] update --- .../core/model_validators/batched_const_matmul_validator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 9c9200546..188a4bcac 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -113,13 +113,15 @@ def validate(self) -> Information | None: "gemm implementation." ), ) + # https://github.com/openvinotoolkit/openvino/issues/36272 explanation = ( f"Model contains {len(offenders)} batched MatMul(s) with a constant " f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot " f"select an implementation for a batched MatMul with a constant " f"operand, causing a '[GPU] Failed to select implementation ... gemm' " f"compile failure. The untie-constant-batched-matmul surgery makes " - f"the operand runtime-valued without changing numerics." + f"the operand runtime-valued without changing numerics. " + f"It is fixed in openvino==2026.2.0, so no need to apply the surgery if using that version or later." ) return Information( explanation=explanation, From d9f5ca760195107a041fdc2bcd79f72a5ce756dc Mon Sep 17 00:00:00 2001 From: hualxie Date: Mon, 8 Jun 2026 17:04:37 +0800 Subject: [PATCH 3/7] use EPName --- .../core/model_validators/batched_const_matmul_validator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 188a4bcac..62825dfc2 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from ...models.onnx_model import ONNXModel from ...models.runtime_checks import PatternRuntime + from ....utils.constants import EPName logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ def __init__( self, model: ONNXModel, op_runtime_results: list[PatternRuntime] | None = None, - ep: str | None = None, + ep: EPName | None = None, device: str | None = None, ) -> None: super().__init__(model, op_runtime_results=op_runtime_results) @@ -121,7 +122,8 @@ def validate(self) -> Information | None: f"operand, causing a '[GPU] Failed to select implementation ... gemm' " f"compile failure. The untie-constant-batched-matmul surgery makes " f"the operand runtime-valued without changing numerics. " - f"It is fixed in openvino==2026.2.0, so no need to apply the surgery if using that version or later." + f"It is fixed in openvino==2026.2.0, so no need to apply the surgery " + f"if using that version or later." ) return Information( explanation=explanation, From 9dbfb6fad609c1ac9a80f10f3d1a9d586c66f7ef Mon Sep 17 00:00:00 2001 From: hualxie Date: Tue, 9 Jun 2026 09:42:13 +0800 Subject: [PATCH 4/7] sort --- .../core/model_validators/batched_const_matmul_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 62825dfc2..65ea953cd 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -28,9 +28,9 @@ if TYPE_CHECKING: + from ....utils.constants import EPName from ...models.onnx_model import ONNXModel from ...models.runtime_checks import PatternRuntime - from ....utils.constants import EPName logger = logging.getLogger(__name__) From 79bcc3f5ec6473145cd249ab11b4430975888fef Mon Sep 17 00:00:00 2001 From: hualxie Date: Wed, 10 Jun 2026 10:39:27 +0800 Subject: [PATCH 5/7] fix(optim): address review comments for untie batched constant MatMul - Use loop index for the untied operand name instead of node.name, which is optional in ONNX and can be blank/duplicated (would collide and yield an invalid graph). - Update docstring to describe the actual Cast/Reshape/Slice/Sub construction (was stale ReduceMin wording) and document the non-empty-first-input assumption. - Split the Slice starts/ends/axes initializers into distinct named tensors. - Note the Constant-node detection gap in the validator (shared with surgery). - Add a test for two unnamed batched-const MatMuls (name-collision regression). --- .../batched_const_matmul_validator.py | 6 ++ src/winml/modelkit/optim/pipes/surgery.py | 39 ++++++++----- tests/unit/optim/pipes/test_pipe_surgery.py | 56 ++++++++++++++++--- 3 files changed, 78 insertions(+), 23 deletions(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index 65ea953cd..d13b7f5ce 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -81,6 +81,12 @@ def validate(self) -> Information | None: if not self._is_enabled(): return None + # Known gap: constants expressed as `Constant` op nodes (rather than + # graph initializers) are not detected here. The `untie-constant-batched + # -matmul` surgery in surgery.py has the same limitation, so detection + # and surgery stay consistent. Most exporters and ORT preprocessing emit + # weights as initializers, so this covers the disentangled-attention case + # in practice; `Constant`-node weights would need handling on both sides. initializers = {init.name for init in self.graph.initializer} rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer} diff --git a/src/winml/modelkit/optim/pipes/surgery.py b/src/winml/modelkit/optim/pipes/surgery.py index fed9b6ce6..8250a19ad 100644 --- a/src/winml/modelkit/optim/pipes/surgery.py +++ b/src/winml/modelkit/optim/pipes/surgery.py @@ -350,10 +350,17 @@ def _untie_constant_batched_matmul( so they fold into 3D constants and hit this case. Fix: route each such constant operand through ``Add(const, zero)`` where - ``zero = Sub(s, s)`` and ``s = ReduceMin(Cast(first_input -> float))``. - ``s`` is data-dependent, so OpenVINO's constant folder cannot collapse + ``zero`` is a runtime ``[1]`` tensor built from the first graph input's + *data*: ``Cast(first_input -> float) -> Reshape([-1]) -> Slice([0:1])`` + yields a single element ``elem``, and ``zero = Sub(elem, elem) == 0.0``. + ``zero`` is data-dependent, so OpenVINO's constant folder cannot collapse the Add back into a packed gemm weight, yet ``+ 0`` leaves the values unchanged and the single batched MatMul is preserved (no perf cost). + + Assumption: the first graph input has at least one element at runtime. + The ``Slice([0:1])`` is out of bounds for a zero-sized input (e.g. a + dynamic batch dimension fed an empty batch), which would raise at + inference time rather than produce a zero. """ from onnx import TensorProto, helper, numpy_helper @@ -403,13 +410,16 @@ def _untie_constant_batched_matmul( new_inits.append(numpy_helper.from_array(np.array([-1], dtype=np.int64), f"{prefix}_m1")) new_nodes.append(helper.make_node("Reshape", [xf, f"{prefix}_m1"], [flat], name=flat)) elem = f"{prefix}_elem" - new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}_0")) - new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), f"{prefix}_1")) - new_nodes.append( - helper.make_node( - "Slice", [flat, f"{prefix}_0", f"{prefix}_1", f"{prefix}_0"], [elem], name=elem - ) - ) + # Slice(flat, starts=[0], ends=[1], axes=[0]) -> the first element. + # starts and axes are distinct tensors even though both hold [0], so a + # future edit to one role cannot silently corrupt the other. + starts = f"{prefix}_slice_starts" + ends = f"{prefix}_slice_ends" + axis = f"{prefix}_slice_axis" + new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), starts)) + new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), ends)) + new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), axis)) + new_nodes.append(helper.make_node("Slice", [flat, starts, ends, axis], [elem], name=elem)) # zero = elem - elem == 0.0 (data-dependent, so it is not folded away). zero_f32 = f"{prefix}_zero_f32" new_nodes.append(helper.make_node("Sub", [elem, elem], [zero_f32], name=zero_f32)) @@ -421,19 +431,20 @@ def zero_for(dtype: int) -> str: name = zero_by_dtype.get(dtype) if name is None: name = f"{prefix}_zero_{dtype}" - new_nodes.append( - helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name) - ) + new_nodes.append(helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name)) zero_by_dtype[dtype] = name return name untied = 0 - for node, idx in targets: + # Index the loop rather than node.name: node names are optional in ONNX + # and exporters routinely leave them blank or duplicated, so deriving + # `dyn` from the name would collide and produce an invalid graph. + for untie_idx, (node, idx) in enumerate(targets): const_name = node.input[idx] dtype = initializers[const_name].data_type if dtype not in (TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE): continue - dyn = f"{prefix}_{node.name}_in{idx}".replace("/", "_") + dyn = f"{prefix}_untied{untie_idx}_in{idx}" new_nodes.append( helper.make_node("Add", [const_name, zero_for(dtype)], [dyn], name=dyn) ) diff --git a/tests/unit/optim/pipes/test_pipe_surgery.py b/tests/unit/optim/pipes/test_pipe_surgery.py index fb98763d3..16df08d79 100644 --- a/tests/unit/optim/pipes/test_pipe_surgery.py +++ b/tests/unit/optim/pipes/test_pipe_surgery.py @@ -409,9 +409,7 @@ class TestUntieConstantBatchedMatmulProcess: def test_constant_operand_becomes_runtime_valued(self) -> None: """The MatMul no longer consumes the initializer directly.""" model = _make_batched_const_matmul_model() - result = SurgeryPipe().process( - model, SurgeryPipeConfig(untie_constant_batched_matmul=True) - ) + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) matmul = next(n for n in result.graph.node if n.op_type == "MatMul") initializer_names = {init.name for init in result.graph.initializer} @@ -447,15 +445,55 @@ def test_numerics_unchanged(self) -> None: def test_two_dim_constant_is_left_untouched(self) -> None: """Rank-2 constant gemm compiles on OV GPU, so it must not be rewritten.""" model = _make_batched_const_matmul_model(const_rank=2) - result = SurgeryPipe().process( - model, SurgeryPipeConfig(untie_constant_batched_matmul=True) - ) + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) assert not any(n.op_type == "Add" for n in result.graph.node) def test_constant_on_lhs_is_handled(self) -> None: """A constant rank-3 operand on the LHS is untied too.""" model = _make_batched_const_matmul_model(const_on_rhs=False) - result = SurgeryPipe().process( - model, SurgeryPipeConfig(untie_constant_batched_matmul=True) - ) + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) assert any(n.op_type == "Add" for n in result.graph.node) + + def test_duplicate_node_names_do_not_collide(self) -> None: + """Two target MatMuls with empty names produce a valid graph. + + Node names are optional in ONNX; exporters routinely leave them blank. + The generated dynamic-operand names must be unique regardless, or the + transformed graph would have colliding tensor names and fail validation. + """ + from onnx import TensorProto, helper + + rng = np.random.RandomState(0) + w1 = numpy_helper.from_array(rng.randn(2, 4, 5).astype(np.float32), "W1") + w2 = numpy_helper.from_array(rng.randn(2, 5, 6).astype(np.float32), "W2") + # Both MatMuls deliberately left unnamed (name=""). + mm1 = helper.make_node("MatMul", ["data", "W1"], ["mid"], name="") + mm2 = helper.make_node("MatMul", ["mid", "W2"], ["out"], name="") + graph = helper.make_graph( + [mm1, mm2], + "test_dup_names", + [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])], + [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 6])], + initializer=[w1, w2], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + result = SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True)) + + # Both constants are untied and the graph stays structurally valid. + add_nodes = [n for n in result.graph.node if n.op_type == "Add"] + assert len(add_nodes) == 2 + assert len({n.output[0] for n in add_nodes}) == 2 + onnx.checker.check_model(result) + + # Numerics are unchanged versus the original model. + import onnxruntime as ort + + feed = {"data": np.random.RandomState(7).randn(2, 3, 4).astype(np.float32)} + ref = ort.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + got = ort.InferenceSession( + result.SerializeToString(), providers=["CPUExecutionProvider"] + ).run(None, feed)[0] + np.testing.assert_array_equal(ref, got) From 06a04c971bc444248131651d4594ba1bc5f1909a Mon Sep 17 00:00:00 2001 From: xieofxie Date: Fri, 12 Jun 2026 11:43:37 +0800 Subject: [PATCH 6/7] remove needs_context --- .../analyze/core/model_validators/base.py | 13 ++++++++++- .../batched_const_matmul_validator.py | 22 +++---------------- .../model_validator_manager.py | 10 ++++----- .../pattern_matching_validator.py | 19 +--------------- 4 files changed, 21 insertions(+), 43 deletions(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/base.py b/src/winml/modelkit/analyze/core/model_validators/base.py index cd0432253..226078957 100644 --- a/src/winml/modelkit/analyze/core/model_validators/base.py +++ b/src/winml/modelkit/analyze/core/model_validators/base.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: + from ....utils.constants import EPName from ...models.information import Information from ...models.onnx_model import ONNXModel from ...models.runtime_checks import PatternRuntime @@ -34,19 +35,26 @@ class ModelValidator(ABC): model_proto: ONNX ModelProto extracted from model graph: Shorthand for model_proto.graph op_runtime_results: List of PatternRuntime results from runtime checker (optional) + ep: Execution provider name (optional) + device: Device type, e.g. "NPU", "GPU", "CPU" (optional) """ def __init__( self, model: ONNXModel, op_runtime_results: list[PatternRuntime] | None = None, + ep: EPName | None = None, + device: str | None = None, ) -> None: - """Initialize validator with ONNX model and optional runtime results. + """Initialize validator with ONNX model and optional context. Args: model: ONNXModel wrapper to validate op_runtime_results: List of PatternRuntime results from runtime checker. Used to enrich validators with OP-level information. + ep: Execution provider name. Validators that gate on EP read this. + device: Device type (e.g., "NPU", "GPU", "CPU"). Validators that gate + on device read this. Raises: ValueError: If model is invalid @@ -55,6 +63,9 @@ def __init__( self.model_proto = model.get_model() self.graph = self.model_proto.graph self.op_runtime_results = op_runtime_results or [] + # Annotate explicitly so the EPName Literal is not widened to ``str``. + self.ep: EPName | None = ep + self.device = device logger.debug( f"Initialized {self.validator_name} for model with {len(self.graph.node)} nodes" diff --git a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py index d13b7f5ce..91a3114c1 100644 --- a/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/batched_const_matmul_validator.py @@ -20,18 +20,12 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING from ...models.information import Action, ActionItem, ActionLevel, Information from ...utils import infer_ihv_from_ep_name from .base import ModelValidator -if TYPE_CHECKING: - from ....utils.constants import EPName - from ...models.onnx_model import ONNXModel - from ...models.runtime_checks import PatternRuntime - logger = logging.getLogger(__name__) # Surgery capability enabled when the pattern is detected (kebab-case to match @@ -42,17 +36,6 @@ class BatchedConstMatMulValidator(ModelValidator): """Detect batched MatMul with a constant operand (OpenVINO GPU only).""" - def __init__( - self, - model: ONNXModel, - op_runtime_results: list[PatternRuntime] | None = None, - ep: EPName | None = None, - device: str | None = None, - ) -> None: - super().__init__(model, op_runtime_results=op_runtime_results) - self.ep = ep - self.device = device - @property def validator_name(self) -> str: """Name of this validator for logging/reporting.""" @@ -67,12 +50,13 @@ def _is_enabled(self) -> bool: """Only relevant for OpenVINO (Intel IHV) on GPU.""" if (self.device or "").upper() != "GPU": return False - if not self.ep: + ep = self.ep + if not ep: return False try: from ...models.ihv_type import IHVType - return infer_ihv_from_ep_name(self.ep) == IHVType.INTEL + return infer_ihv_from_ep_name(ep) == IHVType.INTEL except Exception: # pragma: no cover - defensive return False diff --git a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py index d658dbe53..6b240e983 100644 --- a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py +++ b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py @@ -68,7 +68,6 @@ class ModelValidatorManager: "batched_const_matmul": { "class": BatchedConstMatMulValidator, "enabled_devices": ["GPU"], # OpenVINO GPU gemm impl-selection issue - "needs_context": True, # validator self-gates on EP (Intel IHV) }, } @@ -121,10 +120,11 @@ def __init__( ) continue - ctor_kwargs: dict = {"op_runtime_results": self.op_runtime_results} - if validator_config.get("needs_context"): - ctor_kwargs["ep"] = self.ep - ctor_kwargs["device"] = self.device + ctor_kwargs: dict = { + "op_runtime_results": self.op_runtime_results, + "ep": self.ep, + "device": self.device, + } try: self.validators.append(validator_class(self.model, **ctor_kwargs)) diff --git a/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py b/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py index afdb59a2f..8d0b5aa17 100644 --- a/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py @@ -12,17 +12,13 @@ import json import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar +from typing import ClassVar from ...models import ModelTag from ...models.information import Action, ActionLevel, Information from .base import ModelValidator -if TYPE_CHECKING: - from ...models.onnx_model import ONNXModel - from ...models.runtime_checks import PatternRuntime - logger = logging.getLogger(__name__) @@ -95,19 +91,6 @@ class PatternMatchingValidator(ModelValidator): ), ] - def __init__( - self, - model: ONNXModel, - op_runtime_results: list[PatternRuntime] | None = None, - ) -> None: - """Initialize validator. - - Args: - model: ONNXModel wrapper to validate - op_runtime_results: List of PatternRuntime results from runtime checker - """ - super().__init__(model, op_runtime_results) - @property def validator_name(self) -> str: """Return validator name.""" From 486bba8d13e2c09b0db465faabb9cad7144531e3 Mon Sep 17 00:00:00 2001 From: xieofxie Date: Fri, 12 Jun 2026 12:09:37 +0800 Subject: [PATCH 7/7] always set --- .../model_validator_manager.py | 9 ++++++--- .../core/model_validators/test_validators.py | 18 +++++++++++++----- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py index 6b240e983..bcf8af563 100644 --- a/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py +++ b/src/winml/modelkit/analyze/core/model_validators/model_validator_manager.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: + from ....utils.constants import EPName from ...models.information import Information from ...models.onnx_model import ONNXModel from ...models.runtime_checks import PatternRuntime @@ -76,8 +77,9 @@ def __init__( model: ONNXModel, enabled_validators: list[str] | None = None, op_runtime_results: list[PatternRuntime] | None = None, - device: str | None = None, - ep: str | None = None, + *, + device: str, + ep: EPName, ) -> None: """Initialize validator manager. @@ -89,6 +91,7 @@ def __init__( Used to enrich validators with OP-level information. device: Device type (e.g., "NPU", "GPU", "CPU"). Used to filter validators based on device constraints. + ep: Execution provider name. Forwarded to validators that gate on EP. Raises: ValueError: If model is not valid ONNXModel instance @@ -97,7 +100,7 @@ def __init__( self.model = model self.model_proto = model.get_model() self.op_runtime_results = op_runtime_results or [] - self.device = device or "NPU" + self.device = device self.ep = ep self.enabled_validators = enabled_validators or list(self.VALIDATORS.keys()) diff --git a/tests/unit/analyze/core/model_validators/test_validators.py b/tests/unit/analyze/core/model_validators/test_validators.py index 787224aa6..07085d7fd 100644 --- a/tests/unit/analyze/core/model_validators/test_validators.py +++ b/tests/unit/analyze/core/model_validators/test_validators.py @@ -287,7 +287,9 @@ def test_run_all_validators(self): ) ] - manager = ModelValidatorManager(model, op_runtime_results=op_runtime_results) + manager = ModelValidatorManager( + model, op_runtime_results=op_runtime_results, device="NPU", ep="QNNExecutionProvider" + ) information = manager.run_all_validators() # Should find at least constant folding issue @@ -329,7 +331,11 @@ def test_selective_validators(self): # Enable only constant folding manager = ModelValidatorManager( - model, enabled_validators=["constant_folding"], op_runtime_results=op_runtime_results + model, + enabled_validators=["constant_folding"], + op_runtime_results=op_runtime_results, + device="NPU", + ep="QNNExecutionProvider", ) # Should have exactly one validator @@ -344,11 +350,11 @@ def test_invalid_model_proto_raises_error(self): """Test that invalid model raises AttributeError when trying to get model.""" # Passing None should fail when trying to call get_model() with pytest.raises(AttributeError): - ModelValidatorManager(None) # type: ignore + ModelValidatorManager(None, device="NPU", ep="QNNExecutionProvider") # type: ignore # Passing a string should fail when trying to call get_model() with pytest.raises(AttributeError): - ModelValidatorManager("not a model") # type: ignore + ModelValidatorManager("not a model", device="NPU", ep="QNNExecutionProvider") # type: ignore def test_unknown_validator_logs_warning(self, caplog): """Test that unknown validator names are handled gracefully.""" @@ -373,6 +379,8 @@ def test_unknown_validator_logs_warning(self, caplog): model, enabled_validators=["unknown_validator", "constant_folding"], op_runtime_results=op_runtime_results, + device="NPU", + ep="QNNExecutionProvider", ) # Should only have constant_folding validator (unknown ones are skipped) @@ -434,7 +442,7 @@ def test_skipped_for_two_dim_constant(self): def test_manager_wires_validator_for_openvino_gpu(self): """Manager constructs the validator and surfaces the action for OV GPU.""" model = create_onnx_model_wrapper(_make_batched_const_matmul_proto()) - manager = ModelValidatorManager(model, device="GPU", ep="openvino") + manager = ModelValidatorManager(model, device="GPU", ep="OpenVINOExecutionProvider") names = [v.validator_name for v in manager.validators] assert "BatchedConstMatMulValidator" in names infos = manager.run_all_validators()