diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp index ad14af98865..ed8f7ed10b1 100644 --- a/backends/cortex_m/ops/op_quantized_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -44,7 +44,7 @@ bool validate_conv2d_arguments( executorch::aten::ArrayRef channels_last_order(kChannelsLastDimOrder, 4); - if (input.size(1) > 1 && input.dim_order() != channels_last_order) { + if (input.size(1) > 1 && !is_channels_last_tensor(input)) { ET_LOG( Error, "quantized_conv2d_out: input must have channels_last dim_order (NHWC)"); @@ -52,7 +52,7 @@ bool validate_conv2d_arguments( return false; } - if (output.size(1) > 1 && output.dim_order() != channels_last_order) { + if (output.size(1) > 1 && !is_channels_last_tensor(input)) { ET_LOG( Error, "quantized_conv2d_out: output must have channels_last dim_order (NHWC)"); diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index a33703489fd..9c80e301c20 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -258,12 +258,16 @@ def quantized_mul_impl( @register_fake("cortex_m::minimum") def minimum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: - assert self.dtype == other.dtype, ( - "Cortex-M minimum: dtype mismatch — " - f"got self.dtype={self.dtype}, other.dtype={other.dtype}" - ) - broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) - return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device) + # other is a scalar, so use initial shape. + if other.numel() == 1: + return torch.empty_like(self) + else: + # otherwise broadcast the shape. + return torch.empty( + torch.broadcast_shapes(self.shape, other.shape), + dtype=self.dtype, + device=self.device, + ) @impl(lib, "minimum", "CompositeExplicitAutograd") @@ -281,8 +285,16 @@ def maximum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: "Cortex-M maximum: dtype mismatch — " f"got self.dtype={self.dtype}, other.dtype={other.dtype}" ) - broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) - return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device) + # other is a scalar, so use initial shape. + if other.numel() == 1: + return torch.empty_like(self) + else: + # otherwise broadcast the shape. + return torch.empty( + torch.broadcast_shapes(self.shape, other.shape), + dtype=self.dtype, + device=self.device, + ) @impl(lib, "maximum", "CompositeExplicitAutograd") @@ -578,7 +590,9 @@ def quantized_conv2d_impl( result += output_offset result = torch.clamp(result, activation_min, activation_max) - return result.to(torch.int8) + # TODO - this enforces all convolution layers to result in channels last. + # This issue does comes min/mul layers from decomposition of hard swish + return result.to(torch.int8, memory_format=torch.channels_last) # =================================================================== diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index c8bb743e278..1b1b0363123 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -7,6 +7,7 @@ from .clamp_hardswish_pass import ClampHardswishPass # noqa from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa from .decompose_hardswish_pass import DecomposeHardswishPass # noqa +from .decompose_mean_pass import DecomposeMeanPass # noqa from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index bd3fad1cf94..93ca7222af0 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -15,6 +15,7 @@ ClampHardswishPass, ConvertToCortexMPass, DecomposeHardswishPass, + DecomposeMeanPass, QuantizedOpFusionPass, ReplaceQuantNodesPass, ) @@ -43,6 +44,7 @@ class CortexMPassManager(PassManager): ScalarsToAttributePass, ReplaceScalarWithTensorArgPass, ClampHardswishPass, + DecomposeMeanPass, ] def __init__(self, exported_program, passes=None): diff --git a/backends/cortex_m/passes/decompose_hardswish_pass.py b/backends/cortex_m/passes/decompose_hardswish_pass.py index 36ca6bd759d..3fa854a7779 100644 --- a/backends/cortex_m/passes/decompose_hardswish_pass.py +++ b/backends/cortex_m/passes/decompose_hardswish_pass.py @@ -123,5 +123,6 @@ def call(self, graph_module: GraphModule) -> PassResult: if modified: graph_module.graph.eliminate_dead_code() graph_module.recompile() + graph_module = super().call(graph_module).graph_module return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/decompose_mean_pass.py b/backends/cortex_m/passes/decompose_mean_pass.py new file mode 100644 index 00000000000..588f0356508 --- /dev/null +++ b/backends/cortex_m/passes/decompose_mean_pass.py @@ -0,0 +1,38 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument + + +class DecomposeMeanPass(ExportPass): + """ + Adds a clamp operation before hardswish to ensure input is in the range [-3, inf). + + By doing this before quantization the output range of the preceeding op is minimized, + potentially improving accuracy. + """ + + def call_operator( + self, + op: EdgeOpOverload, + args: tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op == torch.ops.aten.adaptive_avg_pool2d.default: + op = torch.ops.aten.avg_pool2d.default + input_tensor = args[0] + shape = input_tensor.data.shape + stride = [1, 1] + kernel_size = [shape[-2], shape[-1]] + args = (args[0], kernel_size, stride, [0, 0], 0, 0) + + return super().call_operator(op, args, kwargs, meta) diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index d582ec08e47..38062a10673 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -102,7 +102,7 @@ def _get_mul_replacement(self, args, meta): return exir_ops.edge.cortex_m.quantized_mul.default, args def _get_minimum_replacement(self, args, meta): - if args[0].data.dtype != torch.int8: + if args[0].data.dtype not in (torch.int8, torch.int32): return exir_ops.edge.aten.minimum.default, args return exir_ops.edge.cortex_m.minimum.default, args diff --git a/backends/cortex_m/quantizer/operator_configs.py b/backends/cortex_m/quantizer/operator_configs.py index dadee30fa41..19cd1f1ff54 100644 --- a/backends/cortex_m/quantizer/operator_configs.py +++ b/backends/cortex_m/quantizer/operator_configs.py @@ -18,7 +18,9 @@ # ----------------- OPERATOR PATTERN PRESETS ----------------- BINARY_OP_PATTERNS = [ [torch.ops.aten.add.Tensor], + [torch.ops.aten.add_.Tensor], [torch.ops.aten.mul.Tensor], + [torch.ops.aten.mul_.Tensor], [torch.ops.aten.hardswish.default], [torch.ops.aten.hardswish_.default], ] diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index bf29b21d310..a66970861f6 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -21,13 +21,7 @@ build_root_test_dir="${et_root_dir}/arm_test/arm_semihosting_executor_runner_cor select_ops_list="\ aten::add.out,\ -aten::clamp.out,\ -aten::convolution.out,\ -aten::div.out,\ -aten::mean.out,\ aten::mul.out,\ -aten::relu.out,\ -aten::view_copy.out,\ -dim_order_ops::_to_dim_order_copy.out" +aten::convolution.out" ${build_executor_runner} --pte=semihosting --target=ethos-u55-128 --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" diff --git a/backends/cortex_m/test/models/test_mobilenet_v3.py b/backends/cortex_m/test/models/test_mobilenet_v3.py index 598d71ed212..764a99617d3 100644 --- a/backends/cortex_m/test/models/test_mobilenet_v3.py +++ b/backends/cortex_m/test/models/test_mobilenet_v3.py @@ -3,51 +3,40 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import pytest import torch +from executorch.backends.arm.test.common import parametrize from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase from torchvision import models -# TODO: Update as more ops are converted by CMSIS-NN ops. ops_before_transforms: dict[str, int] = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 34, - "executorch_exir_dialects_edge__ops_aten_addmm_default": 2, - "executorch_exir_dialects_edge__ops_aten_clamp_default": 56, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 6, "executorch_exir_dialects_edge__ops_aten_convolution_default": 52, - "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28, - "executorch_exir_dialects_edge__ops_aten_mean_dim": 10, - "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28, - "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_hardswish_default": 19, "executorch_exir_dialects_edge__ops_aten_relu_default": 14, - "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, - "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 178, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 109, + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 10, + "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 9, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 9, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 6, + "executorch_exir_dialects_edge__ops_aten_linear_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 104, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 120, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 101, } + ops_after_transforms: dict[str, int] = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 28, # Not lowered due to broadcasting - "executorch_exir_dialects_edge__ops_aten_addmm_default": 0, "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 6, "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 2, - "executorch_exir_dialects_edge__ops_aten_clamp_default": 56, - "executorch_exir_dialects_edge__ops_aten_convolution_default": 52, - "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28, - "executorch_exir_dialects_edge__ops_aten_mean_dim": 10, - "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28, - "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 0, - "executorch_exir_dialects_edge__ops_aten_relu_default": 14, - "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, - "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 0, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 0, - "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 162, - "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 101, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 52, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 28, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_avg_pool2d_default": 10, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, } model = models.mobilenet_v3_small(weights=None) -example_input = torch.randn(1, 3, 224, 224) +example_input = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last) test_cases = { @@ -58,7 +47,7 @@ } -@pytest.mark.skip("Skip until add + linear fix are upstreamed.") +@parametrize("test_case", test_cases) def test_dialect_mv3(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( @@ -68,7 +57,7 @@ def test_dialect_mv3(test_case): ) -@pytest.mark.skip("Skip until add + linear fix are upstreamed.") +@parametrize("test_case", test_cases) def test_implementation_mv3(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_implementation(qtol=1)