
Commit 2a7e17e

Cortex_M backend:: Add support for AvgPool2d (#16178)
- Adds support for avg_pool2d to the Cortex-M backend, without support for ceil_mode and count_include_pad.
- Adds a filter function to the quantizer so that avg_pool2d is not quantized when ceil_mode or count_include_pad is True.

Signed-off-by: Saoirse Stewart <[email protected]>
1 parent 28da6a8 commit 2a7e17e

File tree: 8 files changed, +305 −2 lines changed

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_avg_pool2d.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
backends/cortex_m/ops/op_quantized_avg_pool2d.cpp

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
/*
 * Copyright 2025 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "cortex_m_ops_common.h"

extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

Tensor& quantized_avg_pool2d_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const Scalar& zero_point,
    const Scalar& multiplier,
    const Scalar& shift,
    Tensor& out) {
  if (input.dim() != 4 || out.dim() != 4) {
    ET_LOG(Error, "quantized_avg_pool2d_out: tensors must be 4-D");
    context.fail(Error::InvalidArgument);
    return out;
  }
  int32_t batch = static_cast<int32_t>(input.size(0));
  int32_t channels = static_cast<int32_t>(input.size(1));
  int32_t input_h = static_cast<int32_t>(input.size(2));
  int32_t input_w = static_cast<int32_t>(input.size(3));
  int32_t kernel_h = static_cast<int32_t>(kernel_size[0]);
  int32_t kernel_w = static_cast<int32_t>(kernel_size[1]);
  int32_t stride_h = static_cast<int32_t>(stride[0]);
  int32_t stride_w = static_cast<int32_t>(stride[1]);
  int32_t pad_h = static_cast<int32_t>(padding[0]);
  int32_t pad_w = static_cast<int32_t>(padding[1]);
  int32_t output_h = static_cast<int32_t>(out.size(2));
  int32_t output_w = static_cast<int32_t>(out.size(3));
  const int32_t activation_min = std::numeric_limits<int8_t>::min();
  const int32_t activation_max = std::numeric_limits<int8_t>::max();

  const int8_t* input_data = input.const_data_ptr<int8_t>();
  int8_t* output_data = out.mutable_data_ptr<int8_t>();

  cmsis_nn_context cmsis_ctx;
  cmsis_ctx.buf = nullptr;
  cmsis_ctx.size = 0;
  cmsis_nn_pool_params pool_params;
  pool_params.stride.h = stride_h;
  pool_params.stride.w = stride_w;
  pool_params.padding.h = pad_h;
  pool_params.padding.w = pad_w;
  pool_params.activation.min = activation_min;
  pool_params.activation.max = activation_max;

  cmsis_nn_dims input_dims{batch, input_h, input_w, channels};
  cmsis_nn_dims filter_dims{1, kernel_h, kernel_w, 1};
  cmsis_nn_dims output_dims{batch, output_h, output_w, channels};

  arm_cmsis_nn_status status = arm_avgpool_s8(
      &cmsis_ctx,
      &pool_params,
      &input_dims,
      input_data,
      &filter_dims,
      &output_dims,
      output_data);
  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "quantized_avg_pool2d_out: arm_avgpool_s8 failed with status [%d]",
        status);
    context.fail(Error::Internal);
  }
  return out;
}

} // namespace native
} // namespace cortex_m
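Because avg_pool2d is annotated by the SharedQspecQuantizer (see the quantizer.py change below), input and output share the same scale and zero point, so the kernel can average the raw int8 values directly. A rough Python reference of what the s8 pooling computes per output element; this is an illustrative approximation, and its rounding may not match arm_avgpool_s8 exactly:

import torch
import torch.nn.functional as F


def avg_pool_s8_reference(x_int8: torch.Tensor, kernel: int, stride: int) -> torch.Tensor:
    # Average the raw int8 codes (valid because input and output share quantization
    # parameters), then round and clamp back into the int8 range.
    pooled = F.avg_pool2d(x_int8.float(), kernel, stride)
    return pooled.round().clamp(-128, 127).to(torch.int8)


x = torch.randint(-128, 128, (1, 2, 4, 4), dtype=torch.int8)
print(avg_pool_s8_reference(x, kernel=2, stride=2).shape)  # torch.Size([1, 2, 2, 2])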

backends/cortex_m/ops/operators.py

Lines changed: 73 additions & 0 deletions
@@ -11,7 +11,9 @@
 import torch
 import torch.nn.functional as F
 from executorch.backends.cortex_m.passes.passes_utils import (
+    dequantize_per_tensor_cmsis,
     is_channel_broadcast,
+    quantize_per_tensor_cmsis,
     requantize_cmsis,
     SHIFT_INT8,
 )
@@ -577,3 +579,74 @@ def quantized_conv2d_impl(
     result = torch.clamp(result, activation_min, activation_max)
 
     return result.to(torch.int8)
+
+
+# ===================================================================
+# QUANTIZED AVG_POOL2D OPERATION DEFINITION
+# ===================================================================
+
+lib.define(
+    "quantized_avg_pool2d("
+    "Tensor input, "
+    "int[] kernel_size, "
+    "int[] stride, "
+    "int[] padding, "
+    "Scalar zero_point, "
+    "Scalar multiplier, "
+    "Scalar shift"
+    ") -> Tensor"
+)
+lib.define(
+    "quantized_avg_pool2d.out("
+    "Tensor input, "
+    "int[] kernel_size, "
+    "int[] stride, "
+    "int[] padding, "
+    "Scalar zero_point, "
+    "Scalar multiplier, "
+    "Scalar shift, "
+    "*, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::quantized_avg_pool2d")
+def quantized_avg_pool2d_meta(
+    input: torch.Tensor,
+    kernel_size: Sequence[int],
+    stride: Sequence[int],
+    padding: Sequence[int],
+    zero_point: int,
+    multiplier: int,
+    shift: int,
+) -> torch.Tensor:
+    # Compute output shape as in PyTorch avg_pool2d
+    output = F.avg_pool2d(input, kernel_size, stride, padding)
+    return torch.empty_like(output, dtype=torch.int8)
+
+
+@impl(lib, "quantized_avg_pool2d", "CompositeExplicitAutograd")
+def quantized_avg_pool2d_impl(
+    input: torch.Tensor,
+    kernel_size: Sequence[int],
+    stride: Sequence[int],
+    padding: Sequence[int],
+    zero_point: int,
+    multiplier: int,
+    shift: int,
+) -> torch.Tensor:
+
+    dequant_input = dequantize_per_tensor_cmsis(input, zero_point, multiplier, shift)
+
+    # TODO: implement count_include_pad=True, ceil_mode=True.
+    result = F.avg_pool2d(
+        dequant_input,
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        count_include_pad=False,
+        ceil_mode=False,
+    )
+    result = quantize_per_tensor_cmsis(result, zero_point, multiplier, shift)
+    output = torch.clamp(result, -128, 127)
+    return output.to(torch.int8)
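For the supported configuration (ceil_mode=False), the output spatial size follows the usual floor formula, which is also what the fake kernel above recovers by running F.avg_pool2d on the meta tensor. A small illustrative check (shapes and values are made up for the example):

import torch
import torch.nn.functional as F

kernel, stride, padding = 2, 2, 0
x = torch.randn(1, 8, 16, 16)

# out = floor((in + 2 * padding - kernel) / stride) + 1 for each spatial dim
out_h = (x.shape[2] + 2 * padding - kernel) // stride + 1
out_w = (x.shape[3] + 2 * padding - kernel) // stride + 1
print(out_h, out_w)  # 8 8

# F.avg_pool2d agrees, which is why the meta function can reuse it for shape inference
print(F.avg_pool2d(x, kernel, stride, padding).shape)  # torch.Size([1, 8, 8, 8])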

backends/cortex_m/ops/operators.yaml

Lines changed: 6 additions & 0 deletions
@@ -58,3 +58,9 @@
   kernels:
     - arg_meta: null
       kernel_name: cortex_m::quantized_conv2d_out
+
+- func: cortex_m::quantized_avg_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, Scalar zero_point, Scalar multiplier, Scalar shift, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::quantized_avg_pool2d_out

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 23 additions & 0 deletions
@@ -122,6 +122,27 @@ def _get_permute_replacement(self, args, meta):
         args = (args[0], perms)
         return exir_ops.edge.cortex_m.transpose.default, args
 
+    def _get_avg_pool2d_replacement(self, args, meta):
+        if (
+            meta.data.get("input_qparams", {}) == {}
+            or meta.data.get("output_qparams", {}) == {}
+        ):
+            return exir_ops.edge.aten.avg_pool2d.default, args
+
+        # Extract values
+        scale = meta["input_qparams"][0].scale
+        zero_point = meta["input_qparams"][0].zp
+
+        output_mult, output_shift = quantize_multiplier_aot(scale)
+        args = (
+            *args[0:-2],
+            zero_point,
+            output_mult,
+            output_shift,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_avg_pool2d.default, args
+
     def call_operator(
         self,
         op: EdgeOpOverload,
@@ -141,6 +162,8 @@ def call_operator(
                 op, args = self._get_maximum_replacement(args, meta)
             case exir_ops.edge.aten.permute_copy.default:
                 op, args = self._get_permute_replacement(args, meta)
+            case exir_ops.edge.aten.avg_pool2d.default:
+                op, args = self._get_avg_pool2d_replacement(args, meta)
             case _:
                 pass
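The quantize_multiplier_aot helper is not shown in this diff; assuming it follows the usual CMSIS-NN convention of encoding a float scale as a Q31 multiplier plus a power-of-two shift, a minimal sketch of that decomposition looks like this (illustrative only, not the actual helper):

import math

def quantize_multiplier_sketch(scale: float) -> tuple[int, int]:
    # Decompose scale so that scale ~= multiplier * 2**(shift - 31),
    # with multiplier a Q31 value in [2**30, 2**31).
    if scale == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(scale)      # scale = mantissa * 2**shift, 0.5 <= mantissa < 1
    multiplier = round(mantissa * (1 << 31))
    if multiplier == (1 << 31):              # rounding pushed the mantissa to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift

print(quantize_multiplier_sketch(0.02))  # (1374389535, -5): 0.64 * 2**-5 == 0.02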

backends/cortex_m/quantizer/quantizer.py

Lines changed: 15 additions & 1 deletion
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Callable, List, Optional
+from typing import Callable, cast, List, Optional
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
@@ -315,6 +315,7 @@ class SharedQspecQuantizer(Quantizer):
         # Min/Max/Mean
         torch.ops.aten.minimum.default,
         torch.ops.aten.maximum.default,
+        torch.ops.aten.avg_pool2d.default,
         # Data shuffling
         torch.ops.aten.permute.default,
         torch.ops.aten.permute_copy.default,
@@ -402,7 +403,20 @@ def _annotate_shared_cluster(self, root_node: Node) -> None:
             mark_node_as_annotated(node, input_qspec_map, shared_qspec)
 
     def annotate(self, model: GraphModule) -> None:
+        """
+        Annotate shared quantization specs for supported ops, skipping avg_pool2d
+        when ceil_mode or count_include_pad is True.
+        """
         for node in model.graph.nodes:
+            # Skip avg_pool2d when ceil_mode=True or count_include_pad=True:
+            # CMSIS-NN does not directly support these modes. TODO: add support.
+            if node.target is torch.ops.aten.avg_pool2d.default:
+                ceil_mode = cast(bool, node.args[4]) if len(node.args) > 4 else False
+                count_include_pad = (
+                    cast(bool, node.args[5]) if len(node.args) > 5 else True
+                )
+                if ceil_mode or count_include_pad:
+                    continue
             if node.target in self.targets and not self._is_annotated(node):
                 self._annotate_shared_cluster(node)
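Concretely, the filter above means only the first of these module configurations is annotated for int8 quantization; the other two fall back to a floating-point aten.avg_pool2d. These instances are illustrative and mirror the test cases below:

import torch

# Quantized: ceil_mode=False and count_include_pad=False (note that the
# nn.AvgPool2d default for count_include_pad is True, so it must be disabled).
quantized_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2, count_include_pad=False)

# Skipped by the quantizer, stays as float avg_pool2d:
skipped_ceil = torch.nn.AvgPool2d(3, stride=2, padding=1, ceil_mode=True, count_include_pad=False)
skipped_pad = torch.nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=True)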

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.test.common import parametrize
from executorch.backends.cortex_m.test.tester import (
    CortexMTester,
    McuTestCase,
    ramp_tensor,
)


class CortexMAvgPool2d(torch.nn.Module):
    ops_before_transforms = {
        "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
    }

    ops_after_transforms = {
        "executorch_exir_dialects_edge__ops_cortex_m_quantized_avg_pool2d_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
    }

    def __init__(
        self, kernel_size, stride, padding=0, ceil_mode=False, count_include_pad=False
    ):
        super().__init__()
        self.pool = torch.nn.AvgPool2d(
            kernel_size,
            stride,
            padding,
            ceil_mode=ceil_mode,
            count_include_pad=count_include_pad,
        )

    def forward(self, x):  # noqa: D102
        return self.pool(x)


# Prepare test cases: simple 2x2 pool on 4x4, and 3x3 stride 1 on 3x3
test_cases = {
    "avgpool_2x2": McuTestCase(
        CortexMAvgPool2d(kernel_size=2, stride=2), (ramp_tensor(0, 15, (1, 1, 4, 4)),)
    ),
    "avgpool_3x3_s1": McuTestCase(
        CortexMAvgPool2d(kernel_size=3, stride=1, padding=1),
        (ramp_tensor(0, 8, (1, 1, 3, 3)),),
    ),
    # Additional pooling configurations: padding, stride, ceil_mode, count_include_pad
    "avgpool_2x2_pad1": McuTestCase(
        CortexMAvgPool2d(kernel_size=2, stride=2, padding=1),
        (ramp_tensor(0, 24, (1, 1, 5, 5)),),
    ),
    "avgpool_3x3_s2_pad1": McuTestCase(
        CortexMAvgPool2d(kernel_size=3, stride=2, padding=1),
        (ramp_tensor(0, 15, (1, 1, 4, 4)),),
    ),
}

test_cases_fp = {
    "avgpool_3x3_s2_pad1_ceil": McuTestCase(
        CortexMAvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True),
        (ramp_tensor(0, 15, (1, 1, 4, 4)),),
    ),
    "avgpool_3x3_s2_pad1_countinc": McuTestCase(
        CortexMAvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=True),
        (ramp_tensor(0, 15, (1, 1, 4, 4)),),
    ),
}


@parametrize("test_case", test_cases)
def test_dialect_avg_pool2d(test_case):
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect(
        test_case.model.ops_before_transforms,
        test_case.model.ops_after_transforms,
        qtol=1,
    )


@parametrize("test_case", test_cases_fp)
def test_dialect_avg_pool2d_fp(test_case):
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect(
        {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1},
        {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1},
        qtol=1,
    )


@parametrize("test_case", test_cases)
def test_implementation_avg_pool2d(test_case):
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_implementation(qtol=1)
backends/cortex_m/test/tester.py

Lines changed: 2 additions & 1 deletion
@@ -41,7 +41,8 @@ def __init__(self):
                 torch.ops.aten.hardsigmoid_.default,
                 torch.ops.aten.hardswish.default,
                 torch.ops.aten.hardswish_.default,
-            ]
+            ],
+            _check_ir_validity=False,
         )
         super().__init__(config)