diff --git a/src/winml/modelkit/analyze/core/information_engine.py b/src/winml/modelkit/analyze/core/information_engine.py index d4c574b16..a3cfe1e1f 100644 --- a/src/winml/modelkit/analyze/core/information_engine.py +++ b/src/winml/modelkit/analyze/core/information_engine.py @@ -297,6 +297,7 @@ def _check_model(self) -> list[Information]: self._model, op_runtime_results=self._op_runtime_results, device=self._device, + ep=self._ep, ) manager_init_ms = int((time.perf_counter() - manager_init_start) * 1000) diff --git a/src/winml/modelkit/analyze/core/model_validators/__init__.py b/src/winml/modelkit/analyze/core/model_validators/__init__.py index 77cd7b924..4504269da 100644 --- a/src/winml/modelkit/analyze/core/model_validators/__init__.py +++ b/src/winml/modelkit/analyze/core/model_validators/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .base import ModelValidator +from .batched_const_matmul_validator import BatchedConstMatMulValidator from .constant_folding_validator import ConstantFoldingValidator from .dynamic_input_validator import DynamicInputValidator from .model_validator_manager import ModelValidatorManager @@ -21,6 +22,7 @@ __all__ = [ + "BatchedConstMatMulValidator", "ConstantFoldingValidator", "DynamicInputValidator", "ModelValidator", diff --git a/src/winml/modelkit/analyze/core/model_validators/base.py b/src/winml/modelkit/analyze/core/model_validators/base.py index cd0432253..226078957 100644 --- a/src/winml/modelkit/analyze/core/model_validators/base.py +++ b/src/winml/modelkit/analyze/core/model_validators/base.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: + from ....utils.constants import EPName from ...models.information import Information from ...models.onnx_model import ONNXModel from ...models.runtime_checks import PatternRuntime @@ -34,19 +35,26 @@ class ModelValidator(ABC): model_proto: ONNX ModelProto extracted from model graph: Shorthand for model_proto.graph op_runtime_results: List of PatternRuntime results from runtime 
class BatchedConstMatMulValidator(ModelValidator):
    """Detect batched MatMul with a constant operand (OpenVINO GPU only).

    A MatMul node is reported when exactly one of its two inputs is a graph
    initializer and that initializer has rank >= 3. When at least one such
    node exists, ``validate`` returns an ``Information`` carrying a REQUIRED
    ``GraphOptimization`` action that switches on the
    ``untie-constant-batched-matmul`` surgery.

    The validator is a no-op unless the device is ``"GPU"`` (compared
    case-insensitively) and the execution provider maps to the Intel IHV.
    """

    @property
    def validator_name(self) -> str:
        """Name of this validator for logging/reporting."""
        return "BatchedConstMatMulValidator"

    @property
    def pattern_id(self) -> str:
        """Pattern ID for Information objects."""
        return "MODEL/BatchedConstantMatMul"

    def _is_enabled(self) -> bool:
        """Return True only for an Intel-IHV execution provider on GPU.

        Returns:
            False when the device is not GPU, when no EP was supplied, or when
            the IHV cannot be inferred from the EP name; True when the inferred
            IHV is Intel.
        """
        if (self.device or "").upper() != "GPU":
            return False
        if not self.ep:
            return False
        try:
            from ...models.ihv_type import IHVType

            return infer_ihv_from_ep_name(self.ep) == IHVType.INTEL
        except Exception:  # pragma: no cover - defensive
            # Best effort: an EP name the IHV lookup cannot handle must not
            # break analysis, but leave a trace so a silently disabled
            # validator can be diagnosed.
            logger.debug(
                "%s: could not infer IHV for EP %r; validator disabled",
                self.validator_name,
                self.ep,
                exc_info=True,
            )
            return False

    def validate(self) -> Information | None:
        """Detect batched MatMul with a single constant rank>=3 operand.

        Returns:
            ``None`` when the validator is not enabled for this EP/device or
            no offending MatMul is found; otherwise an ``Information`` whose
            single REQUIRED action enables the surgery flag.
        """
        if not self._is_enabled():
            return None

        # Known gap: constants expressed as `Constant` op nodes (rather than
        # graph initializers) are not detected here. The
        # `untie-constant-batched-matmul` surgery in surgery.py has the same
        # limitation, so detection and surgery stay consistent. Most exporters
        # and ORT preprocessing emit weights as initializers, so this covers
        # the disentangled-attention case in practice; `Constant`-node weights
        # would need handling on both sides.
        #
        # One mapping serves both the "is this input an initializer?" test and
        # the rank lookup.
        rank_by_init = {init.name: len(init.dims) for init in self.graph.initializer}

        offenders: list[str] = []
        for node in self.graph.node:
            if node.op_type != "MatMul" or len(node.input) != 2:
                continue
            const_inputs = [name for name in node.input if name in rank_by_init]
            # Exactly one constant operand (two-constant MatMuls fold away and
            # never reach gemm impl selection).
            if len(const_inputs) != 1:
                continue
            if rank_by_init[const_inputs[0]] >= 3:
                # Node names are optional; fall back to the constant's name.
                offenders.append(node.name or const_inputs[0])

        if not offenders:
            return None

        # Keep the message short: show at most three offending nodes.
        examples = ", ".join(offenders[:3])
        action = Action(
            pattern_from_id="",
            pattern_to_id="",
            level=ActionLevel.REQUIRED,
            status=None,
            action_items=[
                ActionItem(type="GraphOptimization", optimization_options={_SURGERY_FLAG: True})
            ],
            details=(
                "Enable untie-constant-batched-matmul surgery so the constant "
                "operand becomes runtime-valued and OpenVINO GPU can select a "
                "gemm implementation."
            ),
        )
        # https://github.com/openvinotoolkit/openvino/issues/36272
        explanation = (
            f"Model contains {len(offenders)} batched MatMul(s) with a constant "
            f"operand (examples: {examples}). OpenVINO GPU's oneDNN gemm cannot "
            f"select an implementation for a batched MatMul with a constant "
            f"operand, causing a '[GPU] Failed to select implementation ... gemm' "
            f"compile failure. The untie-constant-batched-matmul surgery makes "
            f"the operand runtime-valued without changing numerics. "
            f"It is fixed in openvino==2026.2.0, so no need to apply the surgery "
            f"if using that version or later."
        )
        return Information(
            explanation=explanation,
            actions=[action],
            pattern_id=self.pattern_id,
            status=None,
        )
Raises: ValueError: If model is not valid ONNXModel instance @@ -91,7 +100,8 @@ def __init__( self.model = model self.model_proto = model.get_model() self.op_runtime_results = op_runtime_results or [] - self.device = device or "NPU" + self.device = device + self.ep = ep self.enabled_validators = enabled_validators or list(self.VALIDATORS.keys()) # Instantiate enabled validators @@ -102,18 +112,25 @@ def __init__( validator_class = validator_config["class"] enabled_devices = validator_config.get("enabled_devices") - # Check device constraint - if enabled_devices is not None and self.device not in enabled_devices: + # Check device constraint (case-insensitive: callers may pass + # "gpu" or "GPU" depending on the build/analyze entry point). + if enabled_devices is not None and (self.device or "").upper() not in { + d.upper() for d in enabled_devices + }: logger.info( f"Validator '{name}' is not enabled for device '{self.device}'. " f"Only enabled for: {enabled_devices}" ) continue + ctor_kwargs: dict = { + "op_runtime_results": self.op_runtime_results, + "ep": self.ep, + "device": self.device, + } + try: - self.validators.append( - validator_class(self.model, op_runtime_results=self.op_runtime_results) - ) + self.validators.append(validator_class(self.model, **ctor_kwargs)) logger.debug(f"Initialized validator: {name}") except Exception: logger.exception(f"Failed to initialize validator {name}") diff --git a/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py b/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py index afdb59a2f..8d0b5aa17 100644 --- a/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py +++ b/src/winml/modelkit/analyze/core/model_validators/pattern_matching_validator.py @@ -12,17 +12,13 @@ import json import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar +from typing import ClassVar from ...models import ModelTag from 
# Surgery capability: feed the constant operand of a batched (rank >= 3)
# MatMul through a runtime no-op so it stops being a compile-time constant.
# OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched
# MatMul with a constant operand (e.g. transformer disentangled-attention
# position terms that fold to 3D constants); once the operand is
# runtime-valued, gemm impl selection succeeds without changing numerics or
# splitting the batched op.
UNTIE_CONSTANT_BATCHED_MATMUL = BoolCapability(
    name="untie-constant-batched-matmul",
    category=CapabilityCategory.SURGERY,
    description=(
        "Make a batched MatMul's constant operand runtime-valued "
        "so OpenVINO GPU can select a gemm implementation"
    ),
    default=False,
    # Custom implementation, not an ORT optimizer.
    ort_name=None,
)
def _untie_constant_batched_matmul(
    self,
    model: onnx.ModelProto,
    verbose: bool = False,
) -> onnx.ModelProto:
    """Make a batched MatMul's constant operand runtime-valued.

    OpenVINO GPU's oneDNN gemm cannot select an implementation for a batched
    (rank >= 3) MatMul where an operand is a compile-time constant: the same
    gemm with a dynamic operand, and 2D constant gemm, both compile fine.
    Transformer disentangled-attention position terms depend only on weights,
    so they fold into 3D constants and hit this case.

    Fix: route each such constant operand through ``Add(const, zero)`` where
    ``zero`` is a runtime ``[1]`` tensor built from the *data* of a runtime
    graph input: ``Cast(input -> float) -> Reshape([-1]) -> Slice([0:1])``
    yields a single element ``elem``, and ``zero = Sub(elem, elem)``.
    ``zero`` is data-dependent, so a constant folder cannot collapse the Add
    back into a packed gemm weight, and the single batched MatMul is kept.

    The runtime input is the first graph input that is tensor-typed, not a
    string tensor, and not also an initializer. Models that list initializers
    among the graph inputs would otherwise hand the Cast a constant, which
    makes ``zero`` foldable and defeats the surgery.

    Assumptions and limits:
        * The chosen input has at least one element at runtime. The
          ``Slice([0:1])`` of a zero-sized input (e.g. an empty batch) gives
          an empty tensor, so the ``Add`` would fail or change shape.
        * The sampled element is finite. ``Sub(elem, elem)`` is NaN when
          ``elem`` is NaN or +/-inf, which would turn the whole operand into
          NaN. Integer inputs such as token ids always satisfy this.
        * Only FLOAT, FLOAT16 and DOUBLE constants are untied; constants of
          other dtypes are left as they are.
        * The default-domain opset must be >= 10, because the inserted
          ``Slice`` takes starts/ends/axes as inputs. Older models are
          skipped with a warning instead of being made invalid.
        * Constants expressed as ``Constant`` nodes are not handled, only
          graph initializers.

    Args:
        model: Model to rewrite. It is modified in place.
        verbose: When True, log one line per rewritten MatMul operand.

    Returns:
        The same ``model`` object, rewritten when at least one operand was
        untied and otherwise untouched.
    """
    from onnx import TensorProto, helper, numpy_helper

    graph = model.graph
    initializers = {init.name: init for init in graph.initializer}

    # Collect (matmul_node, operand_index) where the operand is a constant
    # initializer of rank >= 3. Skip MatMuls whose operands are all constant
    # (those fold away entirely and never reach gemm impl selection).
    targets: list[tuple[onnx.NodeProto, int]] = []
    for node in graph.node:
        if node.op_type != "MatMul" or len(node.input) != 2:
            continue
        const_idx = [i for i, name in enumerate(node.input) if name in initializers]
        if len(const_idx) != 1:
            continue
        idx = const_idx[0]
        if len(initializers[node.input[idx]].dims) >= 3:
            targets.append((node, idx))

    if not targets:
        return model

    # Slice with tensor-valued starts/ends/axes only exists from opset 10 on;
    # emitting it into an older model would produce an invalid graph.
    default_opset = next(
        (imp.version for imp in model.opset_import if imp.domain in ("", "ai.onnx")),
        None,
    )
    if default_opset is not None and default_opset < 10:
        logger.warning(
            "SurgeryPipe: untie-constant-batched-matmul: opset %d < 10 cannot "
            "express the inserted Slice; skipping %d MatMul(s)",
            default_opset,
            len(targets),
        )
        return model

    # Pick a genuine runtime tensor. Graph inputs that are also initializers
    # are constants in disguise, string tensors cannot be cast to float in
    # general, and non-tensor inputs (sequences, maps) cannot feed Cast.
    runtime_input = next(
        (
            graph_input.name
            for graph_input in graph.input
            if graph_input.name not in initializers
            and graph_input.type.HasField("tensor_type")
            and graph_input.type.tensor_type.elem_type != TensorProto.STRING
        ),
        None,
    )
    if runtime_input is None:
        logger.warning(
            "SurgeryPipe: untie-constant-batched-matmul: no graph input to "
            "derive a runtime value from; skipping %d MatMul(s)",
            len(targets),
        )
        return model

    prefix = "winml_ovgpu_untie"
    new_nodes: list[onnx.NodeProto] = []
    new_inits: list[onnx.TensorProto] = []

    # Build a shape-[1] runtime zero from input *data* (not shape — input
    # shapes are static and would be folded). Only ubiquitous ops are used
    # so the static analyzer handles them: a single input element is sliced
    # out and subtracted from itself. A [1] tensor broadcasts against any
    # constant operand, regardless of its rank.
    xf = f"{prefix}_xf"
    new_nodes.append(
        helper.make_node("Cast", [runtime_input], [xf], to=TensorProto.FLOAT, name=xf)
    )
    flat = f"{prefix}_flat"
    new_inits.append(numpy_helper.from_array(np.array([-1], dtype=np.int64), f"{prefix}_m1"))
    new_nodes.append(helper.make_node("Reshape", [xf, f"{prefix}_m1"], [flat], name=flat))
    elem = f"{prefix}_elem"
    # Slice(flat, starts=[0], ends=[1], axes=[0]) -> the first element.
    # starts and axes are distinct tensors even though both hold [0], so a
    # future edit to one role cannot silently corrupt the other.
    starts = f"{prefix}_slice_starts"
    ends = f"{prefix}_slice_ends"
    axis = f"{prefix}_slice_axis"
    new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), starts))
    new_inits.append(numpy_helper.from_array(np.array([1], dtype=np.int64), ends))
    new_inits.append(numpy_helper.from_array(np.array([0], dtype=np.int64), axis))
    new_nodes.append(helper.make_node("Slice", [flat, starts, ends, axis], [elem], name=elem))
    # zero = elem - elem == 0.0 for any finite elem (data-dependent, so it is
    # not folded away).
    zero_f32 = f"{prefix}_zero_f32"
    new_nodes.append(helper.make_node("Sub", [elem, elem], [zero_f32], name=zero_f32))

    # A zero must match each operand's dtype (ONNX has no implicit promotion).
    zero_by_dtype: dict[int, str] = {int(TensorProto.FLOAT): zero_f32}

    def zero_for(dtype: int) -> str:
        """Return the name of a zero tensor of ``dtype``, adding a Cast once per dtype."""
        name = zero_by_dtype.get(dtype)
        if name is None:
            name = f"{prefix}_zero_{dtype}"
            new_nodes.append(helper.make_node("Cast", [zero_f32], [name], to=dtype, name=name))
            zero_by_dtype[dtype] = name
        return name

    supported_dtypes = (TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE)
    untied = 0
    # Index the loop rather than node.name: node names are optional in ONNX
    # and exporters routinely leave them blank or duplicated, so deriving
    # `dyn` from the name would collide and produce an invalid graph.
    for untie_idx, (node, idx) in enumerate(targets):
        const_name = node.input[idx]
        dtype = initializers[const_name].data_type
        if dtype not in supported_dtypes:
            continue
        dyn = f"{prefix}_untied{untie_idx}_in{idx}"
        new_nodes.append(
            helper.make_node("Add", [const_name, zero_for(dtype)], [dyn], name=dyn)
        )
        node.input[idx] = dyn
        untied += 1
        if verbose:
            logger.info(
                "  untie-constant-batched-matmul: %s input[%d] %s -> %s",
                node.name,
                idx,
                const_name,
                dyn,
            )

    if untied == 0:
        # Nothing was rewired, so the helper nodes built above are dropped and
        # the graph is left exactly as it was.
        return model

    graph.initializer.extend(new_inits)
    # Prepend new nodes: their inputs are only graph inputs / initializers,
    # so placing them first keeps the graph topologically sorted.
    existing = list(graph.node)
    del graph.node[:]
    graph.node.extend(new_nodes + existing)

    logger.info(
        "SurgeryPipe: untie-constant-batched-matmul: untied %d batched "
        "MatMul constant operand(s)",
        untied,
    )

    return model
def _make_batched_const_matmul_proto(const_rank: int = 3):
    """Build ``data [2,3,4] @ W -> out [2,3,5]`` with ``W`` as an initializer.

    ``W`` is a float32 zero tensor of shape ``[2,4,5]`` when ``const_rank`` is
    3, and ``[4,5]`` for any other value.
    """
    import numpy as np
    from onnx import numpy_helper

    weight_dims = [4, 5]
    if const_rank == 3:
        weight_dims = [2, *weight_dims]
    weight = numpy_helper.from_array(np.zeros(weight_dims, dtype=np.float32), "W")

    graph_inputs = [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])]
    graph_outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 5])]
    matmul_node = helper.make_node("MatMul", ["data", "W"], ["out"], name="batched_matmul")

    graph = helper.make_graph(
        [matmul_node],
        "batched_const_matmul",
        graph_inputs,
        graph_outputs,
        initializer=[weight],
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])


class TestBatchedConstMatMulValidator:
    """OpenVINO-GPU batched constant MatMul detector."""

    def _validate(self, proto, ep, device):
        """Wrap ``proto``, run the validator for ``ep``/``device``, return its result."""
        from winml.modelkit.analyze.core.model_validators import BatchedConstMatMulValidator

        wrapped = create_onnx_model_wrapper(proto)
        validator = BatchedConstMatMulValidator(wrapped, ep=ep, device=device)
        return validator.validate()

    def test_detects_for_openvino_gpu(self):
        """Emits a GraphOptimization action enabling the surgery for OV GPU."""
        proto = _make_batched_const_matmul_proto()
        info = self._validate(proto, "openvino", "GPU")

        assert info is not None
        assert info.pattern_id == "MODEL/BatchedConstantMatMul"
        first_item = info.actions[0].action_items[0]
        assert first_item.type == "GraphOptimization"
        assert first_item.optimization_options == {"untie-constant-batched-matmul": True}

    def test_skipped_for_openvino_npu(self):
        """Device-gated: NPU is unaffected."""
        proto = _make_batched_const_matmul_proto()
        assert self._validate(proto, "openvino", "NPU") is None

    def test_skipped_for_non_intel_gpu(self):
        """IHV-gated: a non-Intel GPU EP is unaffected."""
        proto = _make_batched_const_matmul_proto()
        assert self._validate(proto, "DmlExecutionProvider", "GPU") is None

    def test_skipped_for_two_dim_constant(self):
        """Rank-2 constant gemm compiles on OV GPU; not flagged."""
        proto = _make_batched_const_matmul_proto(const_rank=2)
        assert self._validate(proto, "openvino", "GPU") is None

    def test_manager_wires_validator_for_openvino_gpu(self):
        """Manager constructs the validator and surfaces the action for OV GPU."""
        wrapped = create_onnx_model_wrapper(_make_batched_const_matmul_proto())
        manager = ModelValidatorManager(wrapped, device="GPU", ep="OpenVINOExecutionProvider")

        registered = {validator.validator_name for validator in manager.validators}
        assert "BatchedConstMatMulValidator" in registered

        reported_ids = [info.pattern_id for info in manager.run_all_validators()]
        assert "MODEL/BatchedConstantMatMul" in reported_ids
# =============================================================================
# UNTIE-CONSTANT-BATCHED-MATMUL TESTS
# =============================================================================


def _make_batched_const_matmul_model(
    *,
    const_rank: int = 3,
    const_on_rhs: bool = True,
) -> onnx.ModelProto:
    """Create a single-MatMul model where exactly one operand is an initializer.

    With ``const_on_rhs`` the graph is ``data [2,3,4] @ W [2,4,5]``; otherwise
    it is ``W [2,3,4] @ data [2,4,5]``. Both produce ``out [2,3,5]``. When
    ``const_rank`` is 2 the leading dimension of ``W`` is dropped.
    """
    from onnx import TensorProto, helper

    weight_rng = np.random.RandomState(0)
    out_dims = [2, 3, 5]
    if const_on_rhs:
        operand_order = ["data", "W"]
        data_dims, weight_dims = [2, 3, 4], [2, 4, 5]
    else:
        operand_order = ["W", "data"]
        data_dims, weight_dims = [2, 4, 5], [2, 3, 4]

    if const_rank == 2:
        weight_dims = weight_dims[1:]

    weight = numpy_helper.from_array(weight_rng.randn(*weight_dims).astype(np.float32), "W")
    node = helper.make_node("MatMul", operand_order, ["out"], name="batched_matmul")
    graph = helper.make_graph(
        [node],
        "test_batched_const_matmul",
        [helper.make_tensor_value_info("data", TensorProto.FLOAT, data_dims)],
        [helper.make_tensor_value_info("out", TensorProto.FLOAT, out_dims)],
        initializer=[weight],
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])


def _apply_untie_surgery(model: onnx.ModelProto) -> onnx.ModelProto:
    """Run SurgeryPipe with only untie-constant-batched-matmul enabled."""
    return SurgeryPipe().process(model, SurgeryPipeConfig(untie_constant_batched_matmul=True))


def _run_first_output_on_cpu(model: onnx.ModelProto, feed: dict) -> np.ndarray:
    """Execute ``model`` on the ORT CPU provider and return its first output."""
    import onnxruntime as ort

    session = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
    return session.run(None, feed)[0]


class TestUntieConstantBatchedMatmulCapability:
    """Capability/config plumbing for untie-constant-batched-matmul."""

    def test_capability_exists(self) -> None:
        """Capability is registered with a None ort_name (custom impl)."""
        assert "untie-constant-batched-matmul" in SURGERY_CAPABILITIES
        capability = SURGERY_CAPABILITIES["untie-constant-batched-matmul"]
        assert capability.ort_name is None

    def test_build_config_enable_via_kwarg(self) -> None:
        """Flag can be toggled through build_config."""
        built = SurgeryPipe.build_config(untie_constant_batched_matmul=True)
        assert built.untie_constant_batched_matmul is True

    def test_should_process_true_when_enabled(self) -> None:
        """should_process is True when only this surgery is enabled."""
        only_untie = SurgeryPipeConfig(untie_constant_batched_matmul=True)
        assert SurgeryPipe.should_process(only_untie) is True


class TestUntieConstantBatchedMatmulProcess:
    """Graph transform behavior."""

    def test_constant_operand_becomes_runtime_valued(self) -> None:
        """The MatMul no longer consumes the initializer directly."""
        result = _apply_untie_surgery(_make_batched_const_matmul_model())

        matmul = next(n for n in result.graph.node if n.op_type == "MatMul")
        initializer_names = {init.name for init in result.graph.initializer}
        # No MatMul input is a direct initializer anymore.
        assert not (set(matmul.input) & initializer_names)
        # An Add node now produces the (formerly constant) operand.
        add_nodes = [n for n in result.graph.node if n.op_type == "Add"]
        assert len(add_nodes) == 1
        assert add_nodes[0].output[0] in matmul.input
        # Graph remains structurally valid.
        onnx.checker.check_model(result)

    def test_numerics_unchanged(self) -> None:
        """+0 tie leaves outputs bit-for-bit identical on ORT CPU."""
        model = _make_batched_const_matmul_model()
        transformed = _apply_untie_surgery(model)

        feed = {"data": np.random.RandomState(7).randn(2, 3, 4).astype(np.float32)}

        ref = _run_first_output_on_cpu(model, feed)
        got = _run_first_output_on_cpu(transformed, feed)
        np.testing.assert_array_equal(ref, got)

    def test_two_dim_constant_is_left_untouched(self) -> None:
        """Rank-2 constant gemm compiles on OV GPU, so it must not be rewritten."""
        result = _apply_untie_surgery(_make_batched_const_matmul_model(const_rank=2))
        op_types = {n.op_type for n in result.graph.node}
        assert "Add" not in op_types

    def test_constant_on_lhs_is_handled(self) -> None:
        """A constant rank-3 operand on the LHS is untied too."""
        result = _apply_untie_surgery(_make_batched_const_matmul_model(const_on_rhs=False))
        op_types = {n.op_type for n in result.graph.node}
        assert "Add" in op_types

    def test_duplicate_node_names_do_not_collide(self) -> None:
        """Two target MatMuls with empty names produce a valid graph.

        Node names are optional in ONNX; exporters routinely leave them blank.
        The generated dynamic-operand names must be unique regardless, or the
        transformed graph would have colliding tensor names and fail validation.
        """
        from onnx import TensorProto, helper

        # Draw W1 before W2 from the same seeded generator so the weights are
        # reproducible.
        weight_rng = np.random.RandomState(0)
        w1 = numpy_helper.from_array(weight_rng.randn(2, 4, 5).astype(np.float32), "W1")
        w2 = numpy_helper.from_array(weight_rng.randn(2, 5, 6).astype(np.float32), "W2")
        # Both MatMuls deliberately left unnamed (name="").
        unnamed_matmuls = [
            helper.make_node("MatMul", ["data", "W1"], ["mid"], name=""),
            helper.make_node("MatMul", ["mid", "W2"], ["out"], name=""),
        ]
        graph = helper.make_graph(
            unnamed_matmuls,
            "test_dup_names",
            [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4])],
            [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 6])],
            initializer=[w1, w2],
        )
        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

        result = _apply_untie_surgery(model)

        # Both constants are untied and the graph stays structurally valid.
        add_outputs = [n.output[0] for n in result.graph.node if n.op_type == "Add"]
        assert len(add_outputs) == 2
        assert len(set(add_outputs)) == 2
        onnx.checker.check_model(result)

        # Numerics are unchanged versus the original model.
        feed = {"data": np.random.RandomState(7).randn(2, 3, 4).astype(np.float32)}
        ref = _run_first_output_on_cpu(model, feed)
        got = _run_first_output_on_cpu(result, feed)
        np.testing.assert_array_equal(ref, got)