Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions backends/arm/_passes/aten_to_tosa_activation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,21 @@ def rewrite_clamp(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | Non
exir_ops.backend.tosa.CLAMP.default,
(node.args[0], *min_max_args),
)


def get_activation_replacement(
    node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
    """Route an activation node to the builder for its TOSA dialect node.

    Args:
        node: FX node whose ``target`` selects the rewrite.
        pass_: Pass instance forwarded unchanged to the selected builder.

    Returns:
        Whatever the selected ``rewrite_*`` builder returns, or ``None`` when
        ``node.target`` is not one of the activation targets handled here.

    """
    op = node.target
    # Targets are compared by equality, one after another, in a fixed order.
    if op == exir_ops.edge.aten.clamp.default:
        builder = rewrite_clamp
    elif op == exir_ops.edge.aten.erf.default:
        builder = rewrite_erf
    elif op == exir_ops.edge.aten.sigmoid.default:
        builder = rewrite_sigmoid
    elif op == exir_ops.edge.aten.tanh.default:
        builder = rewrite_tanh
    else:
        return None
    return builder(node, pass_)
26 changes: 26 additions & 0 deletions backends/arm/_passes/aten_to_tosa_tensor_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

from executorch.backends.transforms.aten_to_dialect_pass import (
AtenToDialectPass,
DialectNodeSpec,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node


def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
    """Build the TOSA ARGMAX dialect node spec for an argmax node.

    Args:
        node: Argmax node. Its first positional argument is the input node;
            the reduction dimension is read from the ``dim`` keyword when
            present, otherwise from the second positional argument.
        pass_: Pass instance; not used by this rewrite.

    Returns:
        A ``DialectNodeSpec`` targeting ``exir_ops.backend.tosa.ARGMAX.default``
        with ``(input_node, axis)`` as arguments and no keyword arguments.

    """
    source = cast(Node, node.args[0])

    if "dim" in node.kwargs:
        raw_axis = node.kwargs["dim"]
    else:
        raw_axis = node.args[1]
    axis = cast(int, raw_axis)

    # A negative dimension counts from the end of the input shape; convert it
    # to the equivalent non-negative index.
    if axis < 0:
        rank = len(source.meta["val"].shape)
        axis = axis + rank

    target = exir_ops.backend.tosa.ARGMAX.default
    return DialectNodeSpec(target, (source, axis), {})
43 changes: 22 additions & 21 deletions backends/arm/_passes/exir_to_tosa_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,38 @@

import executorch.backends.arm.tosa.dialect # noqa: F401
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
rewrite_clamp,
rewrite_erf,
rewrite_sigmoid,
rewrite_tanh,
get_activation_replacement,
)
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax
from executorch.backends.transforms.aten_to_dialect_pass import (
AtenToDialectPass,
DialectNodeSpec,
)
from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node


class ExirToTosaPass(AtenToDialectPass):
    """Rewrite simple EXIR ops to equivalent backend TOSA dialect ops.

    Rewrite functions are grouped by op category in dedicated modules and are
    registered with the shared ATen-to-dialect pass infrastructure through
    ``register_dialect_substitution``.

    """


# Maps each activation ATen edge target to the function used to rewrite it.
# NOTE(review): the same four targets are also registered through the decorated
# _get_activation_replacement wrapper further down in this file — confirm that
# only one of the two registration paths is meant to remain.
_ACTIVATION_FUNCTION_REWRITES = {
exir_ops.edge.aten.clamp.default: rewrite_clamp,
exir_ops.edge.aten.erf.default: rewrite_erf,
exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid,
exir_ops.edge.aten.tanh.default: rewrite_tanh,
}
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
def _get_tensor_operators_replacement(
    node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec:
    """Return the dialect node spec for a registered tensor-operator node.

    Registered on ``ExirToTosaPass`` for ``aten.argmax``; the work is
    delegated to ``rewrite_argmax``.

    """
    replacement = rewrite_argmax(node, pass_)
    return replacement

# Rewrite tables keyed by op category name; each value maps ATen edge targets
# to rewrite functions.
_DIRECT_REWRITE_CATEGORIES = {
"activation_functions": _ACTIVATION_FUNCTION_REWRITES,
}

# Register each category's ATen targets with the function that builds the
# corresponding TOSA dialect node spec.
# register_dialect_substitution(target) returns a decorator, which is applied
# here by calling it directly on the rewrite function.
for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values():
for _edge_target, _rewrite_fn in _rewrite_category.items():
ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
def _get_activation_replacement(
    node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
    """Return the dialect node spec for a registered activation node.

    Registered on ``ExirToTosaPass`` for clamp, erf, sigmoid and tanh; the
    per-target selection is delegated to ``get_activation_replacement``.

    """
    replacement = get_activation_replacement(node, pass_)
    return replacement
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.aten.pad.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.argmax.default,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
exir_ops.edge.aten.eye.default,
Expand Down Expand Up @@ -238,6 +239,7 @@
operator.getitem,
exir_ops.edge.aten.pad.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.argmax.default,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
exir_ops.edge.aten.eye.default,
Expand Down
112 changes: 105 additions & 7 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _negative_checks(
checks: list[OperatorSupportBase] = [RankCheck(reporter, MAX_RANK)]

if not tosa_spec.support_extension("int64"):
checks.append(CheckInt64InputsAndOutputs(exported_program, reporter))
checks.append(CheckInt64InputsAndOutputs(exported_program, reporter, tosa_spec))

checks.extend(_wrapped_additional_checks(additional_checks, reporter))

Expand Down Expand Up @@ -677,7 +677,10 @@ class CheckInt64InputsAndOutputs(OperatorSupportBase):
"""

def __init__(
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
self,
exported_program: ExportedProgram,
reporter: WhyNoPartitionReporter,
tosa_spec: TosaSpecification,
):
"""Initialize the check with program context and reporter."""
self.input_names = [
Expand All @@ -686,6 +689,7 @@ def __init__(
if spec.kind == InputKind.USER_INPUT
]
self.reporter = reporter
self.tosa_spec = tosa_spec
self.int32_min = torch.iinfo(torch.int32).min
self.int32_max = torch.iinfo(torch.int32).max
super().__init__()
Expand All @@ -698,6 +702,104 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
min_val, max_val = int(torch.min(data)), int(torch.max(data))
return min_val >= self.int32_min and max_val <= self.int32_max

def has_rejected_int64_output(
    self, node: torch.fx.Node, tensor_list: Sequence[typing.Any]
) -> bool:
    """Tell whether ``node`` must be rejected because of its output values.

    Argmax nodes are judged by ``_is_tosa_argmax_supported`` instead of by
    their output dtype. Any other node is rejected when at least one
    ``FakeTensor`` in ``tensor_list`` has dtype ``torch.int64``.

    Args:
        node: Node being checked.
        tensor_list: Output values of ``node``; entries that are not
            ``FakeTensor`` instances are ignored.

    Returns:
        ``True`` if the node should be treated as having a rejected int64
        output, ``False`` otherwise.

    """
    argmax_targets = (
        torch.ops.aten.argmax.default,
        exir_ops.edge.aten.argmax.default,
    )
    if node.target in argmax_targets:
        return not self._is_tosa_argmax_supported(node)

    for candidate in tensor_list:
        if isinstance(candidate, FakeTensor) and candidate.dtype == torch.int64:
            return True
    return False

def _is_tosa_argmax_dtype_supported(
    self, node: torch.fx.Node, input_dtype: torch.dtype
) -> bool:
    """Check that the TOSA specification in use can run ARGMAX on a dtype.

    Args:
        node: Node the verdict is reported against on rejection.
        input_dtype: Dtype of the argmax input tensor.

    Returns:
        ``True`` when ``self.tosa_spec`` supports ``input_dtype``. Otherwise a
        rejection reason is sent to ``self.reporter`` and ``False`` is
        returned.

    """
    spec = self.tosa_spec

    # Work out the verdict and the matching rejection reason first, then
    # report once at the end.
    if input_dtype == torch.int8:
        allowed = spec.support_integer()
        reason = "TOSA ARGMAX requires PRO-INT for int8 input."
    elif input_dtype == torch.int16:
        allowed = spec.support_integer() and spec.support_extension("int16")
        reason = "TOSA ARGMAX requires EXT-INT16 for int16 input."
    elif input_dtype in (torch.float16, torch.float32):
        allowed = spec.support_float()
        reason = f"TOSA ARGMAX requires PRO-FP for {input_dtype} input."
    elif input_dtype == torch.bfloat16:
        allowed = spec.support_float() and spec.support_extension("bf16")
        reason = "TOSA ARGMAX requires EXT-BF16 for bfloat16 input."
    else:
        allowed = False
        reason = f"TOSA ARGMAX does not support {input_dtype} input."

    if allowed:
        return True
    self.reporter.report_reject(node, reason)
    return False

def _is_tosa_argmax_supported(self, node: torch.fx.Node) -> bool:
    """Check whether an argmax node can be expressed as TOSA ARGMAX.

    The checks run in this order, and the first failure is reported to
    ``self.reporter``: a reduction dimension is given, it is an ``int``, the
    input dtype is supported, the input has rank at least 1, the dimension is
    in range for that rank, and ``keepdim`` is falsy.

    Args:
        node: Argmax node to check. ``dim`` and ``keepdim`` are read from the
            keyword arguments first, then from positional slots 1 and 2.

    Returns:
        ``True`` when every check passes, ``False`` otherwise.

    """

    def reject(message: str) -> bool:
        self.reporter.report_reject(node, message)
        return False

    if "dim" in node.kwargs:
        dim = node.kwargs["dim"]
    elif len(node.args) > 1:
        dim = node.args[1]
    else:
        dim = None

    if dim is None:
        return reject("TOSA ARGMAX requires an explicit reduction dimension.")
    if not isinstance(dim, int):
        return reject("TOSA ARGMAX requires a statically known reduction dimension.")

    source = typing.cast(torch.fx.Node, node.args[0])
    source_tensor = get_first_fake_tensor(source)
    # The dtype helper reports its own rejection reason.
    if not self._is_tosa_argmax_dtype_supported(node, source_tensor.dtype):
        return False

    rank = len(source_tensor.shape)
    if rank == 0:
        return reject("TOSA ARGMAX requires an input with rank at least 1.")

    # Negative dimensions count from the end of the input shape.
    axis = dim if dim >= 0 else dim + rank
    if not 0 <= axis < rank:
        return reject(
            f"TOSA ARGMAX axis must be in [0, {rank - 1}] but got {dim}.",
        )

    if "keepdim" in node.kwargs:
        keepdim = node.kwargs["keepdim"]
    elif len(node.args) > 2:
        keepdim = node.args[2]
    else:
        keepdim = False
    if keepdim:
        return reject("TOSA ARGMAX does not support keepdim=True.")

    return True

def _check_int64_input_nodes(self, node: torch.fx.Node) -> bool:
"""Check if all int64 input nodes are constant and will be
partitioned.
Expand Down Expand Up @@ -741,11 +843,7 @@ def is_node_supported(
vals = node.meta["val"]
tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]

any_int64 = any(
tensor.dtype == torch.int64
for tensor in tensor_list
if isinstance(tensor, FakeTensor)
)
any_int64 = self.has_rejected_int64_output(node, tensor_list)
# Don't partition nodes with int64 output...
if any_int64:
# ... Except for constant ops that are directly cast to something non-int64.
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
op_sub,
op_sum,
op_to_dim_order_copy,
op_tosa_argmax,
op_tosa_avg_pool2d,
op_tosa_avg_pool2d_adaptive,
op_tosa_cast_to_block_scaled,
Expand Down
63 changes: 63 additions & 0 deletions backends/arm/operators/op_tosa_argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List

import torch.fx
import tosa_serializer as ts

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg


@register_node_visitor
class ArgMaxVisitor(NodeVisitor):
    """Serialize ``tosa.ARGMAX.default`` dialect nodes as TOSA ARGMAX ops."""

    target = "tosa.ARGMAX.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Validate the node and emit one ARGMAX operator into ``tosa_graph``.

        Args:
            node: The ``tosa.ARGMAX.default`` node being lowered.
            tosa_graph: Graph the operator is serialized into.
            inputs: Exactly two arguments: the input tensor and the reduction
                axis (read from ``inputs[1].number``).
            output: Output tensor argument; validated as INT32.

        """
        validate_num_inputs(self.target, inputs, 2)
        validate_valid_dtype(
            self.target,
            inputs[0],
            [
                ts.DType.INT8,
                ts.DType.INT16,
                ts.DType.FP16,
                ts.DType.FP32,
                ts.DType.BF16,
            ],
            self.tosa_spec,
        )
        validate_valid_dtype(self.target, output, ts.DType.INT32, self.tosa_spec)

        axis = inputs[1].number
        if axis < 0:
            # The axis indexes the *input* shape. The fake tensor of ``node``
            # itself is the ARGMAX result, whose rank differs from the input's
            # (the reduced dimension is dropped), so a negative axis must be
            # normalized against the input node's tensor instead.
            input_tensor = get_first_fake_tensor(node.all_input_nodes[0])
            axis += len(input_tensor.size())

        attr = ts.TosaSerializerAttribute()
        attr.ArgMaxAttribute(axis, ts.NanPropagationMode.PROPAGATE)
        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.ARGMAX,
            [inputs[0].name],
            [output.name],
            attr,
        )
Loading
Loading