def _load_parquet_pattern_rule_table(
    self,
    pattern_name: str,
    op_domain: ONNXDomain,
    opset_version: int,
) -> tuple[pd.DataFrame | None, Path, _ParquetConditionTree | None]:
    """Load the per-pattern parquet rule table, consulting the instance cache first.

    The parquet file name is derived from the pattern name, ``self.ep_name``,
    the upper-cased ``self.device_type``, the domain name and the opset version.

    Args:
        pattern_name: Pattern class name used as the file-name prefix.
        op_domain: ONNX domain whose name/value are used for naming and caching.
        opset_version: Opset version used for naming and caching.

    Returns:
        tuple[pd.DataFrame | None, Path, _ParquetConditionTree | None]:
            Loaded dataframe when available, otherwise None,
            the resolved parquet path used for lookup,
            and optional pre-built condition tree.
    """
    parquet_name = (
        f"{pattern_name}_{self.ep_name}_{self.device_type.upper()}_{op_domain.name}"
        f"_opset{opset_version}.parquet"
    )
    parquet_path = resolve_rule_parquet_path(parquet_name)

    # NOTE(review): the trailing ``False`` presumably mirrors a flag slot used by
    # node-level cache keys (e.g. QDQ) -- confirm against ``run_for_node``.
    cache_key = (pattern_name, op_domain.value, opset_version, False)
    if cache_key in self._parquet_rule_table_cache:
        _log_parquet_cache_hit(parquet_path, scope="instance")
        return (
            self._parquet_rule_table_cache[cache_key],
            parquet_path,
            self._parquet_condition_tree_cache.get(cache_key),
        )

    # Cache miss: load through the global loader, then remember both the table
    # and its condition tree (whatever the loader returned is cached as-is).
    table_df = _get_or_load_parquet_table_global(parquet_path)
    condition_tree = _build_condition_tree(table_df)
    self._parquet_rule_table_cache[cache_key] = table_df
    self._parquet_condition_tree_cache[cache_key] = condition_tree
    return table_df, parquet_path, condition_tree


def run_for_subgraph(
    self,
    pattern_match: PatternMatchResult,
    run_unknown_op: bool = False,
) -> PatternRuntime:
    """Run runtime check for subgraph pattern via parquet rule lookup.

    Strategy mirrors ``run_for_node``'s parquet-based path:
    1. Extract conditions from pattern match.
    2. Resolve the pattern parquet table for
       ``(pattern_name, ep, device, domain, opset)``.
    3. Look up the matching row and return its compile/run result.
    4. Fallback to per-node checking when the table or matching row is missing.

    Args:
        pattern_match: PatternMatchResult containing pattern information.
        run_unknown_op: If True, attempt local EP check for unknown ops in
            the per-node fallback path.

    Returns:
        PatternRuntime with check results. When condition extraction fails the
        result is flagged ``no_data=True`` with a ``reason`` describing the
        failure; no exception is propagated from that step.
    """
    pattern_id = pattern_match.pattern.pattern_id
    pattern_name = pattern_match.pattern.__class__.__name__

    def _no_data_runtime(reason: str, debug_details: dict) -> PatternRuntime:
        """Build the 'no data' PatternRuntime shared by both extraction-failure paths."""
        return PatternRuntime(
            pattern_id=pattern_id,
            result=RuntimeTestResult(
                compile=False,
                run=False,
                no_data=True,
                reason=reason,
                debug_details=debug_details,
            ),
            alternatives=self.alternatives,
            pattern_match=pattern_match,
        )

    # Step 1: Extract conditions from PatternMatchResult
    try:
        conditions, infinite_properties = get_query_conditions_for_pattern(
            pattern_match,
            pattern_name,
            self.opset_versions,
            dynamic_axis_strict_mode=self.dynamic_axis_strict_mode,
        )
    except OpOptionalInputSupportError as e:
        # Expected, domain-specific failure: the message alone is enough.
        logger.error("OpOptionalInputSupportError for pattern '%s': %s", pattern_name, e)
        return _no_data_runtime(
            "optional_input_properties_not_found",
            {
                "pattern_name": pattern_name,
                "error_message": str(e),
                "table_path": "",
                "table_file": "",
            },
        )
    except Exception as e:
        # Unexpected failure: keep the traceback in the log so the root cause
        # is not lost when the error is converted into a no-data result.
        logger.exception("Failed to extract conditions for pattern '%s': %s", pattern_name, e)
        return _no_data_runtime(
            "pattern_conditions_extraction_failed",
            {
                "pattern_name": pattern_name,
                "error_message": str(e),
            },
        )

    # Step 2: Determine domain & opset for the parquet table file name.
    # Prefer ai.onnx (com.microsoft opset is always 1, so we use ai.onnx
    # for naming when available); otherwise fall back to the first domain.
    if ONNXDomain.AI_ONNX in self.opset_versions:
        table_domain = ONNXDomain.AI_ONNX
        table_opset = self.opset_versions[ONNXDomain.AI_ONNX]
    else:
        first_entry = next(iter(self.opset_versions.items()), None)
        if first_entry is None:
            # Without any opset there is no table name to resolve; use the
            # documented per-node fallback instead of leaking StopIteration.
            logger.warning(
                "No opset versions available for '%s', checking individual operators",
                pattern_name,
            )
            return self._run_for_subgraph_per_node(pattern_match, pattern_name, run_unknown_op)
        table_domain, table_opset = first_entry

    # Step 3: Load the parquet rule table for this pattern
    table_df, parquet_path, condition_tree = self._load_parquet_pattern_rule_table(
        pattern_name, table_domain, table_opset
    )
    parquet_file = parquet_path.name
    parquet_path_norm = _normalize_table_path(parquet_path)

    if table_df is None:
        logger.info(
            "No pattern parquet '%s' found for '%s', checking individual operators",
            parquet_file,
            pattern_name,
        )
        return self._run_for_subgraph_per_node(pattern_match, pattern_name, run_unknown_op)

    # Step 4: Build filter conditions and look up the matching row
    pattern_columns = condition_tree.condition_columns if condition_tree is not None else []
    table_filter_conditions = _build_table_filter_conditions(
        conditions,
        pattern_columns,
        infinite_properties,
        f"pattern {pattern_name}",
    )
    parquet_filter_conditions = {
        k: encode_rule_condition_value_for_parquet(v)
        for k, v in table_filter_conditions.items()
    }
    query_signature = _build_query_signature(pattern_columns, parquet_filter_conditions)

    cache_key = (
        pattern_name,
        table_domain.value,
        table_opset,
        False,
        query_signature,
    )
    if cache_key in self._node_result_cache:
        # Reuse only the cached result; the PatternRuntime itself is rebuilt so
        # it carries the current pattern_match rather than the one cached first.
        cached = self._node_result_cache[cache_key]
        return PatternRuntime(
            pattern_id=pattern_id,
            result=cached.result,
            alternatives=self.alternatives,
            pattern_match=pattern_match,
        )

    # Try the pre-built condition tree first; fall back to an exact-match query.
    matched_row = None
    row_position = _lookup_row_position_in_condition_tree(
        condition_tree, parquet_filter_conditions
    )
    if row_position is not None:
        matched_row = table_df.iloc[row_position]
    else:
        ret = query_table_exact_match(table_df, parquet_filter_conditions)
        if not ret.empty:
            matched_row = ret.iloc[0]

    if matched_row is None:
        logger.info(
            "Pattern parquet '%s' loaded but properties not matched for '%s': %s",
            parquet_file,
            pattern_name,
            table_filter_conditions,
        )
        return self._run_for_subgraph_per_node(pattern_match, pattern_name, run_unknown_op)

    # A row without the column is reported as compile=False / run=False.
    compile_run = matched_row.get("compile_run_success", (False, False))
    compile_result = bool(compile_run[0])
    run_result = bool(compile_run[1])

    result = RuntimeTestResult(
        compile=compile_result,
        run=run_result,
        reason="",
        no_data=False,
        debug_details={
            "table_path": parquet_path_norm,
            "table_file": parquet_file,
            "opset_version": table_opset,
            "lookup_columns": pattern_columns,
            "query_signature": query_signature,
        },
    )
    pattern_runtime = PatternRuntime(
        pattern_id=pattern_id,
        result=result,
        alternatives=self.alternatives,
        pattern_match=pattern_match,
    )
    self._node_result_cache[cache_key] = pattern_runtime
    return pattern_runtime
a/src/winml/modelkit/pattern/base.py b/src/winml/modelkit/pattern/base.py index 245e5d51b..c164b04cb 100644 --- a/src/winml/modelkit/pattern/base.py +++ b/src/winml/modelkit/pattern/base.py @@ -965,14 +965,21 @@ def _infer_type_mapping(self, skeleton_match_result: "SkeletonMatchResult") -> d Dictionary mapping type parameters (e.g., 'T') to actual types (e.g., 'tensor(float)'). """ schema = self.get_schema() - type_param_to_type = {} + matcher = skeleton_match_result.matcher + type_param_to_type: dict[str, str] = {} for idx, input_param in enumerate(schema.inputs): - if idx < len(skeleton_match_result.inputs): - tensor_name = skeleton_match_result.inputs[idx] - actual_type = skeleton_match_result.matcher.get_tensor_type_str(tensor_name) - if actual_type and input_param.type_str: - type_param_to_type[input_param.type_str] = actual_type + if idx < len(skeleton_match_result.inputs) and input_param.type_str: + actual_type = matcher.get_tensor_type_str(skeleton_match_result.inputs[idx]) + if actual_type: + type_param_to_type.setdefault(input_param.type_str, actual_type) + + if schema.outputs and skeleton_match_result.output: + output_param = schema.outputs[0] + if output_param.type_str: + actual_type = matcher.get_tensor_type_str(skeleton_match_result.output) + if actual_type: + type_param_to_type.setdefault(output_param.type_str, actual_type) return type_param_to_type diff --git a/src/winml/modelkit/pattern/rules/default.json b/src/winml/modelkit/pattern/rules/default.json index bb9794e2e..8ca9f23e0 100644 --- a/src/winml/modelkit/pattern/rules/default.json +++ b/src/winml/modelkit/pattern/rules/default.json @@ -180,6 +180,14 @@ "reason": "Merged axes reduce Transpose dimensionality for better hardware compatibility" } ] + }, + { + "pattern_id": "SUBGRAPH/UnsqueezeCastPattern", + "pattern_class": "UnsqueezeCastPattern", + "module": "winml.modelkit.pattern.unsqueeze_cast_patterns", + "enabled": true, + "flag_name": "unsqueezecast", + "description": "Unsqueeze followed 
by Cast(to=FLOAT) on a constant-axes Unsqueeze" } ] } diff --git a/src/winml/modelkit/pattern/unsqueeze_cast_patterns.py b/src/winml/modelkit/pattern/unsqueeze_cast_patterns.py new file mode 100644 index 000000000..4d6a6216a --- /dev/null +++ b/src/winml/modelkit/pattern/unsqueeze_cast_patterns.py @@ -0,0 +1,231 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Unsqueeze-Cast pattern for ONNX models. + +This module provides a pattern for matching ``Unsqueeze -> Cast(to=FLOAT)`` +subgraphs. The exemplar case lives in google-t5/t5-small.onnx where +``/model/decoder/Unsqueeze_1`` is followed by ``/model/decoder/Cast_1`` to +promote an integer attention-mask tensor to float32 before being added to +attention scores. +""" + +from typing import Any + +import numpy as np +from onnx import TensorProto +from onnx.defs import OpSchema + +from ..onnx import ONNXDomain +from .base import ( + Pattern, + PatternInputGenerator, + PatternMismatchedError, + PatternSchema, + Skeleton, + register_pattern_input_generator, +) +from .match import SkeletonMatchResult +from .op_input_gen import InputShapeConstraint + + +# Cast `to` value this pattern is specialised for: TensorProto.FLOAT == 1. 
_CAST_TO_FLOAT32 = int(TensorProto.FLOAT)


# Schema shared by every UnsqueezeCastPattern instance (returned by get_schema()).
_UNSQUEEZE_CAST_SCHEMA = PatternSchema(
    name="UnsqueezeCastPattern",
    doc=(
        "Unsqueeze followed by Cast(to=FLOAT) pattern.\n"
        "Computes: output = Cast(Unsqueeze(data, axes), to=tensor(float))\n"
        "\n"
        "Attributes:\n"
        "- axes: int64 axes input to the Unsqueeze node (required to be constant).\n"
        "- to: Target dtype enum for the Cast node; constrained to "
        "TensorProto.FLOAT (1).\n"
    ),
    type_constraints=[
        OpSchema.TypeConstraintParam(
            type_param_str="T1",
            allowed_type_strs=[
                "tensor(float16)",
                "tensor(float)",
                "tensor(double)",
                "tensor(uint8)",
                "tensor(uint16)",
                "tensor(uint32)",
                "tensor(uint64)",
                "tensor(int8)",
                "tensor(int16)",
                "tensor(int32)",
                "tensor(int64)",
                "tensor(bfloat16)",
                "tensor(bool)",
            ],
            description="Constrain input type to all numeric tensor types.",
        ),
        OpSchema.TypeConstraintParam(
            type_param_str="T2",
            allowed_type_strs=["tensor(float)"],
            description="Constrain output type to tensor(float).",
        ),
    ],
    inputs=[
        OpSchema.FormalParameter(
            name="data",
            type_str="T1",
            description="Input tensor to be unsqueezed and cast to float.",
            param_option=OpSchema.FormalParameterOption.Single,
            is_homogeneous=True,
            min_arity=1,
            differentiation_category=OpSchema.DifferentiationCategory.Differentiable,
        ),
    ],
    outputs=[
        OpSchema.FormalParameter(
            name="output",
            type_str="T2",
            description="Cast output tensor with an extra axis inserted.",
            param_option=OpSchema.FormalParameterOption.Single,
            is_homogeneous=True,
            min_arity=1,
            differentiation_category=OpSchema.DifferentiationCategory.Differentiable,
        )
    ],
    attributes={
        "axes": OpSchema.Attribute(
            name="axes",
            description="Axes argument of the Unsqueeze node (constant int64 tensor).",
            type=OpSchema.AttrType.INTS,
            required=True,
        ),
        "to": OpSchema.Attribute(
            name="to",
            description="Cast `to` enum; constrained to TensorProto.FLOAT (1).",
            type=OpSchema.AttrType.INT,
            required=True,
        ),
    },
)


class UnsqueezeCastPattern(Pattern):
    """Pattern for Unsqueeze followed by Cast(to=FLOAT).

    Node topology:
    - Node 0 (Unsqueeze): Unsqueeze(data, axes)
    - Node 1 (Cast): Cast(unsqueeze_output, to=FLOAT)

    The ``axes`` input of the Unsqueeze node must be a constant (initializer
    or Constant-node output); otherwise the match is rejected. The Cast
    ``to`` attribute is constrained to ``TensorProto.FLOAT`` (1).
    """

    def get_skeleton(self) -> Skeleton:
        """Return the skeleton structure for the UnsqueezeCast pattern."""
        node_op_types = ["Unsqueeze", "Cast"]
        node_domains = [ONNXDomain.AI_ONNX] * len(node_op_types)

        edges = [
            (-1, 0, 0, 0),  # input data -> Unsqueeze[0]
            (0, 0, 1, 0),  # Unsqueeze output -> Cast[0]
        ]

        return Skeleton(
            node_op_types=node_op_types,
            node_domains=node_domains,
            edges=edges,
            exit_nodes=[1],
            n_inputs=1,
        )

    def get_internal_constants_and_attributes(
        self,
        inputs: dict[str, np.ndarray],
        attributes: dict[str, Any],
        is_constant_map: dict[str, bool],
        domain_versions: dict[ONNXDomain, int],
    ) -> tuple[list[tuple[int, int, np.ndarray]], dict[tuple[int, str], Any]]:
        """Return internal constants for axes and attribute constraint for Cast.to.

        Only ``attributes["axes"]`` is read; it is materialised as an int64
        array bound to input slot 1 of node 0 (Unsqueeze). The Cast node
        (node 1) always receives ``to=TensorProto.FLOAT``, regardless of any
        ``to`` value present in ``attributes``.
        """
        internal_constants: list[tuple[int, int, np.ndarray]] = [
            (0, 1, np.array(attributes["axes"], dtype=np.int64)),
        ]
        internal_attributes: dict[tuple[int, str], Any] = {
            (1, "to"): _CAST_TO_FLOAT32,
        }
        return internal_constants, internal_attributes

    def _infer_schema_attributes(
        self, skeleton_match_result: SkeletonMatchResult
    ) -> dict[str, Any]:
        """Infer ``axes`` (from Unsqueeze) and ``to`` (from Cast) attributes.

        Raises:
            PatternMismatchedError: If the Unsqueeze node has no axes input,
                the axes input is not a known constant, the Cast node has no
                ``to`` attribute, or ``to`` is not ``TensorProto.FLOAT``.
        """
        attributes: dict[str, Any] = {}
        matcher = skeleton_match_result.matcher
        matched_nodes = skeleton_match_result.matched_nodes

        unsqueeze_node = matched_nodes[0]
        if len(unsqueeze_node.input) <= 1:
            raise PatternMismatchedError("Unsqueeze node missing axes input")
        axes_input_name = unsqueeze_node.input[1]
        # Only tensors with a known constant value appear in tensor_values.
        if axes_input_name not in matcher.tensor_values:
            raise PatternMismatchedError(
                f"Unsqueeze axes input '{axes_input_name}' is not a constant"
            )
        attributes["axes"] = tuple(matcher.tensor_values[axes_input_name].tolist())

        cast_node = matched_nodes[1]
        to_attr = next((attr for attr in cast_node.attribute if attr.name == "to"), None)
        if to_attr is None:
            raise PatternMismatchedError("Cast node missing 'to' attribute")
        cast_to = int(to_attr.i)
        # Enforce the documented FLOAT-only constraint here so that rejecting
        # other target dtypes does not depend on checks made elsewhere.
        if cast_to != _CAST_TO_FLOAT32:
            raise PatternMismatchedError(
                f"Cast 'to' attribute is {cast_to}, expected {_CAST_TO_FLOAT32} "
                "(TensorProto.FLOAT)"
            )
        attributes["to"] = cast_to

        return attributes

    def get_schema(self) -> PatternSchema:
        """Return the schema definition for the UnsqueezeCast pattern."""
        return _UNSQUEEZE_CAST_SCHEMA


@register_pattern_input_generator
class UnsqueezeCastPatternInputGenerator(PatternInputGenerator):
    """PatternInputGenerator for UnsqueezeCastPattern."""

    pattern = UnsqueezeCastPattern()
    registration_name = "UnsqueezeCastPattern"

    def get_finite_attribute_sets(self) -> dict[str, list[Any]]:
        """Return finite attribute sets (empty for this pattern)."""
        return {}

    def get_input_and_infinite_attribute_combinations(
        self,
    ) -> list[dict[str, Any]]:
        """Generate input/attribute combinations for testing.

        Two cases are produced: a rank-2 input unsqueezed at axis 1 and a
        rank-3 input unsqueezed at axis 0, both cast to FLOAT.
        """
        return [
            {
                "data": InputShapeConstraint((2, 3)),
                "axes": (1,),
                "to": _CAST_TO_FLOAT32,
            },
            {
                "data": InputShapeConstraint((4, 5, 6)),
                "axes": (0,),
                "to": _CAST_TO_FLOAT32,
            },
        ]

    def derive_properties(self, properties: dict) -> dict:
        """Add convenience properties for parameterised testing.

        Returns a shallow copy of ``properties`` with ``axes_dim`` set to the
        number of entries in ``attr_axes``; the input mapping is not mutated.
        """
        item = properties.copy()
        item["axes_dim"] = len(item["attr_axes"])
        return item

    def get_infinite_property_names(self) -> list[str]:
        """Return names of properties with infinite possible values."""
        return ["attr_axes", "data_shape"]
a/tests/unit/analyze/core/test_unified_pattern_config.py +++ b/tests/unit/analyze/core/test_unified_pattern_config.py @@ -19,9 +19,10 @@ def test_load_default_config(self): htp_patterns = config.get_htp_patterns() # Should load all patterns from default.json - # 8 patterns: Gelu1-4, MatMulAdd, LayerNormPow, LayerNormMul, ReshapeTransposeReshape - assert len(skeleton_patterns) == 8, ( - f"Expected 8 skeleton patterns, got {len(skeleton_patterns)}" + # 9 patterns: Gelu1-4, MatMulAdd, LayerNormPow, LayerNormMul, + # ReshapeTransposeReshape, UnsqueezeCast + assert len(skeleton_patterns) == 9, ( + f"Expected 9 skeleton patterns, got {len(skeleton_patterns)}" ) assert len(htp_patterns) == 1, f"Expected 1 HTP pattern, got {len(htp_patterns)}" @@ -32,6 +33,7 @@ def test_load_default_config(self): "SUBGRAPH/GemmPattern", # MatMulAdd "SUBGRAPH/LayerNormalizationPattern", # Shared by Pow and Mul variants "SUBGRAPH/ReshapeTransposeReshapeOverlyHighDimPattern", + "SUBGRAPH/UnsqueezeCastPattern", } assert skeleton_pattern_ids == expected_ids, f"Pattern IDs mismatch: {skeleton_pattern_ids}" @@ -43,9 +45,9 @@ def test_load_qnn_config_with_inheritance(self): htp_patterns = config.get_htp_patterns() # Should load all patterns from default + QNN overrides - # 9 patterns: 8 from default + TransposeAttentionPattern from QNN - assert len(skeleton_patterns) == 9, ( - f"Expected 9 skeleton patterns, got {len(skeleton_patterns)}" + # 10 patterns: 9 from default + TransposeAttentionPattern from QNN + assert len(skeleton_patterns) == 10, ( + f"Expected 10 skeleton patterns, got {len(skeleton_patterns)}" ) # HTP patterns should be inherited from default assert len(htp_patterns) == 1, f"Expected 1 HTP pattern, got {len(htp_patterns)}" @@ -160,7 +162,7 @@ def test_missing_ihv_config_falls_back_to_default(self): # Should load default patterns with a warning skeleton_patterns = config.get_skeleton_patterns() - assert len(skeleton_patterns) == 8, "Should fall back to default patterns" + assert 
len(skeleton_patterns) == 9, "Should fall back to default patterns" def test_alternatives_with_pattern_class(self): """Test that alternatives with pattern_class field are loaded correctly.""" diff --git a/tests/unit/analyze/pattern/test_pattern_input_generator.py b/tests/unit/analyze/pattern/test_pattern_input_generator.py index 1e4dd1603..1abd8d72a 100644 --- a/tests/unit/analyze/pattern/test_pattern_input_generator.py +++ b/tests/unit/analyze/pattern/test_pattern_input_generator.py @@ -31,7 +31,7 @@ class TestPatternInputGeneratorRegistry: def test_all_patterns_registered(self) -> None: """Test that all patterns are registered.""" registered = get_registered_pattern_input_generators() - assert len(registered) == 19 + assert len(registered) == 20 def test_get_pattern_input_generator(self) -> None: """Test retrieving pattern generators by name.""" diff --git a/tests/unit/analyze/pattern/test_unsqueeze_cast_patterns.py b/tests/unit/analyze/pattern/test_unsqueeze_cast_patterns.py new file mode 100644 index 000000000..ee8091dcb --- /dev/null +++ b/tests/unit/analyze/pattern/test_unsqueeze_cast_patterns.py @@ -0,0 +1,147 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Tests for UnsqueezeCastPattern. + +The exemplar production case is the ``/model/decoder/Unsqueeze_1`` -> +``/model/decoder/Cast_1`` pair in google-t5/t5-small.onnx where a 4-D +attention-mask tensor is unsqueezed and then cast to float32. 
_FLOAT = int(TensorProto.FLOAT)


def _build_unsqueeze_cast_model(
    *,
    data_shape: tuple[int, ...] = (2, 3),
    data_elem_type: int = TensorProto.INT64,
    axes: tuple[int, ...] = (1,),
    cast_to: int = _FLOAT,
    axes_as_initializer: bool = True,
) -> onnx.ModelProto:
    """Build a minimal ONNX model containing only Unsqueeze -> Cast.

    Args:
        data_shape: Static shape of the single ``data`` graph input.
        data_elem_type: ONNX element type of ``data``.
        axes: Axes passed to Unsqueeze; negative values are interpreted
            relative to the output rank.
        cast_to: ONNX element type used for the Cast ``to`` attribute and for
            the declared graph output.
        axes_as_initializer: When True, axes are stored as an initializer;
            otherwise they are exposed as an extra graph input ``axes_dyn``.

    Returns:
        A model (opset 17) that has passed ``onnx.checker.check_model``.
    """
    out_rank = len(data_shape) + len(axes)
    # Normalise negative axes against the output rank; a set is enough because
    # only membership is tested below.
    inserted_axes = {a if a >= 0 else a + out_rank for a in axes}
    out_shape: list[int] = []
    data_iter = iter(data_shape)
    for i in range(out_rank):
        if i in inserted_axes:
            out_shape.append(1)
        else:
            out_shape.append(next(data_iter))

    data = helper.make_tensor_value_info("data", data_elem_type, list(data_shape))
    output = helper.make_tensor_value_info("output", cast_to, out_shape)

    axes_arr = np.array(axes, dtype=np.int64)

    initializers: list[onnx.TensorProto] = []
    nodes: list[onnx.NodeProto] = []
    if axes_as_initializer:
        initializers.append(
            helper.make_tensor("axes", TensorProto.INT64, list(axes_arr.shape), axes_arr.tolist())
        )
        nodes.append(helper.make_node("Unsqueeze", ["data", "axes"], ["unsq_out"], name="unsq"))
    else:
        nodes.append(helper.make_node("Unsqueeze", ["data", "axes_dyn"], ["unsq_out"], name="unsq"))

    nodes.append(helper.make_node("Cast", ["unsq_out"], ["output"], name="cast", to=cast_to))

    graph_inputs = [data]
    if not axes_as_initializer:
        graph_inputs.append(
            helper.make_tensor_value_info("axes_dyn", TensorProto.INT64, [len(axes)])
        )

    graph = helper.make_graph(
        nodes=nodes,
        name="unsqueeze_cast_test",
        inputs=graph_inputs,
        outputs=[output],
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
    onnx.checker.check_model(model)
    return model


def _match(model: onnx.ModelProto) -> list:
    """Run a PatternMatcher with only UnsqueezeCastPattern registered on *model*."""
    matcher = PatternMatcher(model)
    matcher.register_pattern(UnsqueezeCastPattern())
    return matcher.match()


class TestUnsqueezeCastPatternMatching:
    """Topology / constraint behaviour of UnsqueezeCastPattern."""

    def test_matches_float_cast(self) -> None:
        """Default model (int64 input, axes=(1,), Cast to FLOAT) matches once."""
        model = _build_unsqueeze_cast_model()
        results = _match(model)
        assert len(results) == 1
        attrs = results[0].attributes
        assert attrs["axes"] == (1,)
        assert attrs["to"] == _FLOAT

    @pytest.mark.parametrize("axes", [(0,), (1,), (-1,), (0, 2)])
    def test_matches_various_axes(self, axes: tuple[int, ...]) -> None:
        """Positive, negative and multi-axis Unsqueeze all match and report axes as given."""
        model = _build_unsqueeze_cast_model(data_shape=(2, 3, 4), axes=axes)
        results = _match(model)
        assert len(results) == 1
        assert results[0].attributes["axes"] == axes

    def test_rejects_non_float_cast(self) -> None:
        """Cast(to=int32) must not match."""
        model = _build_unsqueeze_cast_model(cast_to=TensorProto.INT32)
        results = _match(model)
        assert results == []

    def test_rejects_non_constant_axes(self) -> None:
        """Unsqueeze with a dynamic (graph input) axes input must not match."""
        model = _build_unsqueeze_cast_model(axes_as_initializer=False)
        results = _match(model)
        assert results == []

    def test_matches_float_input_too(self) -> None:
        """Pattern is dtype-agnostic on the input side: float input also matches."""
        model = _build_unsqueeze_cast_model(data_elem_type=TensorProto.FLOAT)
        results = _match(model)
        assert len(results) == 1


class TestUnsqueezeCastPatternRoundTrip:
    """Self-matching via Pattern.get_onnx_model."""

    def test_get_onnx_model_self_matches(self) -> None:
        """A model generated by the pattern itself is matched by the pattern."""
        pattern = UnsqueezeCastPattern()
        # Seeded generator keeps the test reproducible run to run.
        rng = np.random.default_rng(0)
        inputs = {"data": rng.standard_normal((2, 3)).astype(np.float32)}
        attributes = {"axes": (1,), "to": _FLOAT}
        model = pattern.get_onnx_model(
            inputs,
            attributes,
            is_constant_map={"data": False},
            output_dtypes=["tensor(float)"],
            domain_versions=TEST_DOMAIN_VERSIONS,
        )
        onnx.checker.check_model(model)

        results = _match(model)
        assert len(results) == 1
        assert results[0].attributes["axes"] == (1,)
        assert results[0].attributes["to"] == _FLOAT