Skip to content

Commit ab02129

Browse files
martinlsm and Martin Lindström authored
Arm backend: Partition boundary Q/DQ nodes for INT+FP (#16312)
For full INT lowering without FP support, Q/DQ nodes that are on the boundary of a partition are not included in the partition. With INT+FP support, the backend is able to handle them properly. Therefore, include these boundary nodes in that setting. To add test coverage for this new feature, a new stage called "check_not.exir_quant_nodes" is added for TosaPipelineINT and VgfPipeline when both FP and INT profiles are enabled. This stage verifies that no exir Q/DQ nodes remain in the graph after "to_edge_transform_and_lower" (or "partition" if "to_edge_transform_and_lower" is omitted). If a test fails to partition the boundary Q/DQ nodes in INT+FP lowering, it will be detected in "check_not.exir_quant_nodes". Tests in test_quant_custom_meta.py will run this new check. Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Martin Lindström <[email protected]>
1 parent 7eb43bb commit ab02129

File tree

4 files changed

+120
-46
lines changed

4 files changed

+120
-46
lines changed

backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ def call_operator(self, op, args, kwargs, meta):
5858
if not (input_dtype == torch.float32 and output_dtype == torch.int32):
5959
return super().call_operator(op, args, kwargs, meta)
6060

61+
# For some ops, qparams dtype is inconsistent with fake tensor's dtype.
62+
# Skip decorating if the input is quantized and thus not floating point.
63+
if (
64+
"output_qparams" in input.node.meta
65+
and len(input.node.meta["output_qparams"]) > 0
66+
):
67+
return super().call_operator(op, args, kwargs, meta)
68+
6169
op_full, op_ge, op_floor, op_ceil, op_where = _get_decorated_ops(op)
6270

6371
zero = super().call_operator(

backends/arm/test/misc/test_mixed_type_lowering.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ def combine_op_dicts(*dicts):
1717
return {op: dict(counts) for op, counts in merged.items()}
1818

1919

20+
def repeat_op_dict(op_dict, times):
21+
repeated = {}
22+
for op, dtypes in op_dict.items():
23+
repeated[op] = {dtype: count * times for (dtype, count) in dtypes.items()}
24+
return repeated
25+
26+
2027
# TODO Figure out how to handle multiple dq/q nodes properly
2128
# See backends/arm/_passes/decompose_quant_nodes.py for details
2229
dq_tosa_ops = {
@@ -35,7 +42,6 @@ def combine_op_dicts(*dicts):
3542
"CEIL": {"FP32": 1}, # for rounding
3643
"FLOOR": {"FP32": 1}, # for rounding
3744
}
38-
q_dq_tosa_ops = combine_op_dicts(dq_tosa_ops, q_tosa_ops)
3945

4046

4147
class AddSigmoidMul(torch.nn.Module):
@@ -61,7 +67,12 @@ def test_mixed_type_lowering():
6167
"ADD": {"INT32": 1}, # ADD should be executed in INT32
6268
"MUL": {"INT32": 1}, # MUL should be executed in INT32
6369
},
64-
q_dq_tosa_ops,
70+
repeat_op_dict(
71+
q_tosa_ops, 3
72+
), # Two decomposed boundary Q nodes + one for SIGMOID
73+
repeat_op_dict(
74+
dq_tosa_ops, 2
75+
), # One decomposed boundary DQ node + one for SIGMOID
6576
)
6677

6778
pipeline.add_stage_after(

backends/arm/test/tester/test_pipeline.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -390,14 +390,15 @@ def __init__(
390390
),
391391
}
392392
tosa_version = _require_tosa_version()
393+
tosa_spec: TosaSpecification = tosa_profiles[tosa_version]
393394

394395
compile_spec = common.get_tosa_compile_spec(
395-
tosa_profiles[tosa_version],
396+
tosa_spec,
396397
custom_path=custom_path,
397398
tosa_debug_mode=tosa_debug_mode,
398399
)
399400

400-
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
401+
quantizer = TOSAQuantizer(tosa_spec)
401402
# choose 16A8W quantization config when int16 extension is requested
402403
if "int16" in tosa_extensions:
403404
quantization_config = get_symmetric_a16w8_quantization_config(
@@ -422,7 +423,7 @@ def __init__(
422423
)
423424
self.add_stage(self.tester.quantize, quant_stage, pos=0)
424425

425-
remove_quant_nodes_stage = (
426+
remove_torch_quant_nodes_stage = (
426427
"to_edge_transform_and_lower"
427428
if use_to_edge_transform_and_lower
428429
else "partition"
@@ -440,7 +441,7 @@ def __init__(
440441
suffix="quant_nodes",
441442
)
442443
self.add_stage_after(
443-
remove_quant_nodes_stage,
444+
remove_torch_quant_nodes_stage,
444445
self.tester.check_not,
445446
[
446447
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
@@ -449,6 +450,21 @@ def __init__(
449450
suffix="quant_nodes",
450451
)
451452

453+
# For pure INT lowering, outer exir Q/DQ nodes remain in the graph because we can't partition them.
454+
# In INT+FP lowering, we partition these nodes, so a check is added to verify that.
455+
if tosa_spec.support_integer() and tosa_spec.support_float():
456+
self.add_stage_after(
457+
remove_torch_quant_nodes_stage,
458+
self.tester.check_not,
459+
[
460+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
461+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
462+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
463+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default",
464+
],
465+
suffix="exir_quant_nodes",
466+
)
467+
452468
if run_on_tosa_ref_model:
453469
self.add_stage(
454470
self.tester.run_method_and_compare_outputs,
@@ -1093,6 +1109,12 @@ def __init__(
10931109
transform_passes=transform_passes,
10941110
)
10951111

1112+
remove_torch_quant_nodes_stage = (
1113+
"to_edge_transform_and_lower"
1114+
if use_to_edge_transform_and_lower
1115+
else "partition"
1116+
)
1117+
10961118
if quantize:
10971119
quantizer = VgfQuantizer(compile_spec)
10981120
quantization_config = get_symmetric_quantization_config(
@@ -1104,12 +1126,6 @@ def __init__(
11041126

11051127
self.add_stage(self.tester.quantize, quant_stage, pos=0)
11061128

1107-
remove_quant_nodes_stage = (
1108-
"to_edge_transform_and_lower"
1109-
if use_to_edge_transform_and_lower
1110-
else "partition"
1111-
)
1112-
11131129
if _has_quantizable_inputs(test_data):
11141130
# only add stages if we have quantizable input
11151131
self.add_stage_after(
@@ -1122,7 +1138,7 @@ def __init__(
11221138
suffix="quant_nodes",
11231139
)
11241140
self.add_stage_after(
1125-
remove_quant_nodes_stage,
1141+
remove_torch_quant_nodes_stage,
11261142
self.tester.check_not,
11271143
[
11281144
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
@@ -1141,6 +1157,21 @@ def __init__(
11411157
suffix="quant_nodes",
11421158
)
11431159

1160+
# For pure INT lowering, outer exir Q/DQ nodes remain in the graph because we can't partition them.
1161+
# In INT+FP lowering, we partition these nodes, so a check is added to verify that.
1162+
if tosa_spec.support_integer() and tosa_spec.support_float():
1163+
self.add_stage_after(
1164+
remove_torch_quant_nodes_stage,
1165+
self.tester.check_not,
1166+
[
1167+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
1168+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
1169+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
1170+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default",
1171+
],
1172+
suffix="exir_quant_nodes",
1173+
)
1174+
11441175
if run_on_vulkan_runtime:
11451176
self.add_stage(self.tester.serialize)
11461177
self.add_stage(

backends/arm/tosa/partitioner.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,56 @@ def __init__(
186186
self.additional_checks = additional_checks
187187
self.tosa_spec = compile_spec.tosa_spec
188188

189+
def _detag_boundary_nodes(
190+
self, module: GraphModule, tag: str, reporter: WhyNoPartitionReporter
191+
) -> None:
192+
"""De-tag nodes at the partition boundary.
193+
194+
Remove delegation tags from quantize nodes with inputs outside the
195+
partition and from dequantize nodes with outputs outside the partition.
196+
197+
For non Q/DQ nodes, remove the tag from the first node in the partition
198+
if any input has floating-point dtype.
199+
200+
Args:
201+
tag: The delegation tag assigned to the partition.
202+
reporter: A reporter to log rejected nodes.
203+
module: The GraphModule containing the partition.
204+
205+
"""
206+
207+
# De-tag outermost q-nodes upwards and dq-nodes downwards.
208+
# De-tag if at least one input/output is not part of the partition.
209+
for node in module.graph.nodes:
210+
if not is_partitioned(node, tag):
211+
continue
212+
213+
is_q_node = node.target in Q_OPS
214+
is_dq_node = node.target in DQ_OPS
215+
is_boundary_q_node = is_q_node and not is_partitioned(
216+
node.all_input_nodes[0], tag
217+
)
218+
is_boundary_dq_node = is_dq_node and any(
219+
not is_partitioned(user, tag) for user in node.users
220+
)
221+
222+
if is_boundary_q_node or is_boundary_dq_node:
223+
# Remove tag from quantize node with input outside partition,
224+
# or dequantize node with any output outside partition
225+
del node.meta["delegation_tag"]
226+
elif not is_q_node and not is_dq_node:
227+
# For non Q/DQ nodes, remove tag from first node in partition if any input has fp dtype
228+
for input in node.all_input_nodes:
229+
if is_partitioned(input, tag):
230+
continue
231+
if get_first_fake_tensor(input).dtype.is_floating_point:
232+
reporter.report_reject(
233+
node,
234+
f"Was first node in partition and input {input.name} had fp dtype.",
235+
)
236+
del node.meta["delegation_tag"]
237+
break
238+
189239
def _tag_module( # noqa
190240
self,
191241
module: GraphModule,
@@ -233,39 +283,13 @@ def _tag_module( # noqa
233283
for node in partition.nodes:
234284
node.meta["delegation_tag"] = tag
235285

236-
# De-tag outermost q-nodes upwards and dq-nodes downwards.
237-
# De-tag if at least one input/output is not part of the partition.
238-
for node in module.graph.nodes:
239-
if not is_partitioned(node, tag):
240-
continue
241-
if node.target in Q_OPS:
242-
for input in node.all_input_nodes:
243-
if not is_partitioned(input, tag):
244-
del node.meta["delegation_tag"]
245-
break
246-
continue
247-
248-
if node.target in DQ_OPS:
249-
for user in node.users:
250-
if not is_partitioned(user, tag):
251-
del node.meta["delegation_tag"]
252-
break
253-
continue
254-
255-
if self.tosa_spec.support_float():
256-
continue
257-
258-
if is_partitioned(node, tag):
259-
for input in node.all_input_nodes:
260-
if is_partitioned(input, tag):
261-
continue
262-
if get_first_fake_tensor(input).dtype.is_floating_point:
263-
reporter.report_reject(
264-
node,
265-
f"Was first node in partition and input {input.name} had fp dtype.",
266-
)
267-
del node.meta["delegation_tag"]
268-
break
286+
if self.tosa_spec.support_integer() and not self.tosa_spec.support_float():
287+
# Detag boundary Q/DQ since we cannot handle them without float support
288+
self._detag_boundary_nodes(
289+
module,
290+
tag,
291+
reporter,
292+
)
269293

270294
is_noop_partition = all(
271295
is_noop_clone(node)

0 commit comments

Comments
 (0)