4 changes: 2 additions & 2 deletions backends/cortex_m/ops/op_quantized_conv2d.cpp
@@ -44,15 +44,15 @@ bool validate_conv2d_arguments(
executorch::aten::ArrayRef<executorch::aten::DimOrderType>
channels_last_order(kChannelsLastDimOrder, 4);

-  if (input.size(1) > 1 && input.dim_order() != channels_last_order) {
+  if (input.size(1) > 1 && !is_channels_last_tensor(input)) {
ET_LOG(
Error,
"quantized_conv2d_out: input must have channels_last dim_order (NHWC)");
context.fail(Error::InvalidArgument);
return false;
}

-  if (output.size(1) > 1 && output.dim_order() != channels_last_order) {
+  if (output.size(1) > 1 && !is_channels_last_tensor(output)) {
ET_LOG(
Error,
"quantized_conv2d_out: output must have channels_last dim_order (NHWC)");
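Reviewer note: the helper introduced above is not shown in this diff. Below is a hypothetical sketch of what is_channels_last_tensor presumably checks, reusing the kChannelsLastDimOrder table already visible in this file; it is a reconstruction for illustration, not the backend's actual implementation.

// Commentary, not part of the diff: hypothetical helper sketch.
// Assumes ExecuTorch's Tensor::dim_order() accessor, as used above.
#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace {

// NHWC dim order for a 4-D tensor: N, H, W, C.
constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = {0, 2, 3, 1};

bool is_channels_last_tensor(const executorch::aten::Tensor& tensor) {
  if (tensor.dim() != 4) {
    return false;
  }
  executorch::aten::ArrayRef<executorch::aten::DimOrderType>
      channels_last_order(kChannelsLastDimOrder, 4);
  return tensor.dim_order() == channels_last_order;
}

} // namespace

Using a single helper for both the input and output checks keeps the two validations consistent.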
32 changes: 23 additions & 9 deletions backends/cortex_m/ops/operators.py
@@ -258,12 +258,16 @@ def quantized_mul_impl(

@register_fake("cortex_m::minimum")
def minimum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
-    assert self.dtype == other.dtype, (
-        "Cortex-M minimum: dtype mismatch — "
-        f"got self.dtype={self.dtype}, other.dtype={other.dtype}"
-    )
-    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
-    return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device)
+    # other is a scalar, so use the initial shape.
+    if other.numel() == 1:
+        return torch.empty_like(self)
+    else:
+        # Otherwise, broadcast the shapes.
+        return torch.empty(
+            torch.broadcast_shapes(self.shape, other.shape),
+            dtype=self.dtype,
+            device=self.device,
+        )


@impl(lib, "minimum", "CompositeExplicitAutograd")
@@ -281,8 +285,16 @@ def maximum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
"Cortex-M maximum: dtype mismatch — "
f"got self.dtype={self.dtype}, other.dtype={other.dtype}"
)
broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device)
# other is a scalar, so use initial shape.
if other.numel() == 1:
return torch.empty_like(self)
else:
# otherwise broadcast the shape.
return torch.empty(
torch.broadcast_shapes(self.shape, other.shape),
dtype=self.dtype,
device=self.device,
)


@impl(lib, "maximum", "CompositeExplicitAutograd")
@@ -578,7 +590,9 @@ def quantized_conv2d_impl(
    result += output_offset
    result = torch.clamp(result, activation_min, activation_max)

-    return result.to(torch.int8)
+    # TODO: this forces all convolution layers to produce a channels_last result.
+    # The issue stems from the min/mul layers created by the hardswish decomposition.
+    return result.to(torch.int8, memory_format=torch.channels_last)


# ===================================================================
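Reviewer note: the memory_format argument to Tensor.to in the last hunk reorders strides without changing the logical shape; a quick sanity check:

import torch

result = torch.randn(1, 16, 8, 8)  # NCHW-contiguous float tensor
out = result.to(torch.int8, memory_format=torch.channels_last)

assert out.shape == result.shape  # logical shape is unchanged
assert out.is_contiguous(memory_format=torch.channels_last)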
1 change: 1 addition & 0 deletions backends/cortex_m/passes/__init__.py
@@ -7,6 +7,7 @@
from .clamp_hardswish_pass import ClampHardswishPass # noqa
from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa
from .decompose_hardswish_pass import DecomposeHardswishPass # noqa
+from .decompose_mean_pass import DecomposeMeanPass # noqa
from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa
from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa
from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip
2 changes: 2 additions & 0 deletions backends/cortex_m/passes/cortex_m_pass_manager.py
@@ -15,6 +15,7 @@
ClampHardswishPass,
ConvertToCortexMPass,
DecomposeHardswishPass,
+DecomposeMeanPass,
QuantizedOpFusionPass,
ReplaceQuantNodesPass,
)
@@ -43,6 +44,7 @@ class CortexMPassManager(PassManager):
ScalarsToAttributePass,
ReplaceScalarWithTensorArgPass,
ClampHardswishPass,
+DecomposeMeanPass,
]

def __init__(self, exported_program, passes=None):
1 change: 1 addition & 0 deletions backends/cortex_m/passes/decompose_hardswish_pass.py
@@ -123,5 +123,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
if modified:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
+graph_module = super().call(graph_module).graph_module  # re-run ExportPass to retrace the mutated graph

return PassResult(graph_module, modified)
38 changes: 38 additions & 0 deletions backends/cortex_m/passes/decompose_mean_pass.py
@@ -0,0 +1,38 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch.fx.node import Argument


class DecomposeMeanPass(ExportPass):
    """
    Decomposes adaptive_avg_pool2d (a spatial mean) into a static avg_pool2d
    whose kernel covers the full input height and width, so that it can be
    lowered to the Cortex-M quantized avg_pool2d kernel.
    """

    def call_operator(
        self,
        op: EdgeOpOverload,
        args: tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op == torch.ops.aten.adaptive_avg_pool2d.default:
            # Assumes output_size == [1, 1] (a global average pool), as
            # produced by e.g. MobileNetV3's AdaptiveAvgPool2d(1).
            op = torch.ops.aten.avg_pool2d.default
            input_tensor = args[0]
            shape = input_tensor.data.shape
            stride = [1, 1]
            kernel_size = [shape[-2], shape[-1]]
            args = (args[0], kernel_size, stride, [0, 0], 0, 0)

        return super().call_operator(op, args, kwargs, meta)
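Reviewer note: a quick numerical check of the rewrite for the global-pool case this pass targets. The pass runs on the edge dialect, but the aten semantics are the same, and the positional arguments mirror the ones built in call_operator above:

import torch

x = torch.randn(1, 3, 7, 5)

adaptive = torch.ops.aten.adaptive_avg_pool2d.default(x, [1, 1])
static = torch.ops.aten.avg_pool2d.default(
    x, [x.shape[-2], x.shape[-1]], [1, 1], [0, 0], False, False
)

torch.testing.assert_close(adaptive, static)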
2 changes: 1 addition & 1 deletion backends/cortex_m/passes/quantized_op_fusion_pass.py
@@ -102,7 +102,7 @@ def _get_mul_replacement(self, args, meta):
return exir_ops.edge.cortex_m.quantized_mul.default, args

def _get_minimum_replacement(self, args, meta):
-if args[0].data.dtype != torch.int8:
+if args[0].data.dtype not in (torch.int8, torch.int32):
return exir_ops.edge.aten.minimum.default, args

return exir_ops.edge.cortex_m.minimum.default, args
Expand Down
2 changes: 2 additions & 0 deletions backends/cortex_m/quantizer/operator_configs.py
@@ -18,7 +18,9 @@
# ----------------- OPERATOR PATTERN PRESETS -----------------
BINARY_OP_PATTERNS = [
[torch.ops.aten.add.Tensor],
+[torch.ops.aten.add_.Tensor],
[torch.ops.aten.mul.Tensor],
+[torch.ops.aten.mul_.Tensor],
[torch.ops.aten.hardswish.default],
[torch.ops.aten.hardswish_.default],
]
8 changes: 1 addition & 7 deletions backends/cortex_m/test/build_test_runner.sh
@@ -21,13 +21,7 @@ build_root_test_dir="${et_root_dir}/arm_test/arm_semihosting_executor_runner_cor

select_ops_list="\
aten::add.out,\
-aten::clamp.out,\
-aten::convolution.out,\
-aten::div.out,\
-aten::mean.out,\
-aten::mul.out,\
-aten::relu.out,\
-aten::view_copy.out,\
-dim_order_ops::_to_dim_order_copy.out"
+aten::convolution.out"

${build_executor_runner} --pte=semihosting --target=ethos-u55-128 --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}"
51 changes: 20 additions & 31 deletions backends/cortex_m/test/models/test_mobilenet_v3.py
@@ -3,51 +3,40 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-import pytest
import torch
+from executorch.backends.arm.test.common import parametrize

from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase
from torchvision import models


# TODO: Update as more ops are converted to CMSIS-NN ops.
ops_before_transforms: dict[str, int] = {
-    "executorch_exir_dialects_edge__ops_aten_add_Tensor": 34,
-    "executorch_exir_dialects_edge__ops_aten_addmm_default": 2,
-    "executorch_exir_dialects_edge__ops_aten_clamp_default": 56,
+    "executorch_exir_dialects_edge__ops_aten_add_Tensor": 6,
    "executorch_exir_dialects_edge__ops_aten_convolution_default": 52,
-    "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28,
-    "executorch_exir_dialects_edge__ops_aten_mean_dim": 10,
-    "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28,
-    "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2,
+    "executorch_exir_dialects_edge__ops_aten_hardswish_default": 19,
    "executorch_exir_dialects_edge__ops_aten_relu_default": 14,
-    "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
-    "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56,
-    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 178,
-    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 109,
+    "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 10,
+    "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 9,
+    "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 9,
+    "executorch_exir_dialects_edge__ops_aten_linear_default": 2,
+    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 104,
+    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 120,
+    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 101,
}

ops_after_transforms: dict[str, int] = {
-    "executorch_exir_dialects_edge__ops_aten_add_Tensor": 28,  # Not lowered due to broadcasting
-    "executorch_exir_dialects_edge__ops_aten_addmm_default": 0,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 6,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 2,
-    "executorch_exir_dialects_edge__ops_aten_clamp_default": 56,
-    "executorch_exir_dialects_edge__ops_aten_convolution_default": 52,
-    "executorch_exir_dialects_edge__ops_aten_div_Tensor": 28,
-    "executorch_exir_dialects_edge__ops_aten_mean_dim": 10,
-    "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 28,
-    "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 0,
-    "executorch_exir_dialects_edge__ops_aten_relu_default": 14,
-    "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
-    "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 56,
-    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 0,
-    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 0,
-    "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 162,
-    "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 101,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 52,
-    "executorch_exir_dialects_edge__ops_cortex_m_quantized_mul_default": 28,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantized_avg_pool2d_default": 10,
+    "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 2,
+    "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
}

model = models.mobilenet_v3_small(weights=None)
-example_input = torch.randn(1, 3, 224, 224)
+example_input = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)


test_cases = {
@@ -58,7 +47,7 @@
}


-@pytest.mark.skip("Skip until add + linear fix are upstreamed.")
+@parametrize("test_case", test_cases)
def test_dialect_mv3(test_case):
tester = CortexMTester(test_case.model, test_case.example_inputs)
tester.test_dialect(
@@ -68,7 +57,7 @@ def test_dialect_mv3(test_case):
)


-@pytest.mark.skip("Skip until add + linear fix are upstreamed.")
+@parametrize("test_case", test_cases)
def test_implementation_mv3(test_case):
tester = CortexMTester(test_case.model, test_case.example_inputs)
tester.test_implementation(qtol=1)