Skip to content

Commit 00c4f7e

Browse files
committed
Qualcomm AI Engine Direct - Support Debug Handle and Integrate IntermediateOutputCapturer
1 parent 3233761 commit 00c4f7e

File tree

17 files changed

+366
-257
lines changed

17 files changed

+366
-257
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .remove_redundancy import RemoveRedundancy
4848
from .replace_arange_args import ReplaceArangeArgs
4949
from .replace_inf_values import ReplaceInfValues
50+
from .resolve_debug_handle import ResolveDebugHandle
5051
from .seq_mse import SeqMSE
5152
from .tag_quant_io import TagQuantIO
5253

@@ -94,6 +95,7 @@
9495
RemoveRedundancy,
9596
ReplaceArangeArgs,
9697
ReplaceInfValues,
98+
ResolveDebugHandle,
9799
SeqMSE,
98100
TagQuantIO,
99101
]

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
RemoveRedundancy,
5353
ReplaceArangeArgs,
5454
ReplaceInfValues,
55+
ResolveDebugHandle,
5556
TagQuantIO,
5657
)
5758
from executorch.backends.qualcomm._passes.utils import (
@@ -105,6 +106,7 @@ def get_capture_program_passes():
105106
(Remove0DTensor, True),
106107
(RemoveRedundancy, True),
107108
(TagQuantIO, False),
109+
# ResolveDebugHandle will be added to last, check sorting below
108110
]
109111

110112
passes = OrderedDict()
@@ -162,7 +164,6 @@ def get_to_edge_transform_passes(
162164
for p in passes_job:
163165
self.add_pass(p)
164166
self.solve_constraints()
165-
166167
sorted_passes = self.passes
167168
self.passes = []
168169
for p in sorted_passes:
@@ -173,6 +174,7 @@ def get_to_edge_transform_passes(
173174
if "edge_program" in kwargs:
174175
kwargs["edge_program"] = exported_program
175176
self.add_pass(p(**kwargs))
177+
self.add_pass(ResolveDebugHandle())
176178
return self.passes
177179

178180
def transform_for_to_edge_pipeline(
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import operator
7+
8+
import torch
9+
from executorch.backends.qualcomm.utils.constants import QCOM_DEBUG_HANDLE
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
13+
class ResolveDebugHandle(ExportPass):
14+
"""
15+
Caution: This pass is executed as the last of the edge_passes.
16+
For any passes executed during qnn_preprocess, users will need to handle debug_handle ID themselves.
17+
18+
Description: During passes transformation, some passes might be copying some node's meta when creating a new node,
19+
which means multiple nodes might be sharing the same debug_handle ID while it shouldn't.
20+
This is critical as Intermediate Debugger uses debug handle as key.
21+
debug_handle ID must be resolved so each op gets its own set of debug_handle ID and intermediate output.
22+
"""
23+
24+
def __init__(self):
25+
super(ResolveDebugHandle, self).__init__()
26+
27+
def call(self, graph_module: torch.fx.GraphModule):
28+
handle_counter = 1
29+
visited = set()
30+
for node in graph_module.graph.nodes:
31+
# Assume node is traversed in topological order, adding a check here to be safe.
32+
if node.target == operator.getitem:
33+
source_node = node.args[0]
34+
assert (
35+
source_node.name in visited
36+
), "Graph is not traversed in topological order, unexpected behavior."
37+
node.meta[QCOM_DEBUG_HANDLE] = source_node.meta[QCOM_DEBUG_HANDLE]
38+
elif node.op == "call_function":
39+
node.meta[QCOM_DEBUG_HANDLE] = handle_counter
40+
handle_counter += 1
41+
visited.add(node.name)
42+
43+
graph_module.recompile()
44+
return PassResult(graph_module, True)

backends/qualcomm/builders/node_visitor.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
QCOM_BLOCK_SCALE_OFFSET,
2020
QCOM_BLOCK_SCALES,
2121
QCOM_BLOCK_STORAGE_TYPE,
22+
QCOM_DEBUG_HANDLE,
2223
QCOM_DTYPE,
2324
QCOM_ENCODING,
2425
QCOM_NUM_BLOCKS_PER_AXIS,
@@ -30,7 +31,6 @@
3031
QCOM_SCALE,
3132
QCOM_SCALE_OFFSET,
3233
QCOM_SCALES,
33-
QCOM_TENSOR_NAME,
3434
QCOM_ZERO_POINT,
3535
QCOM_ZERO_POINTS,
3636
)
@@ -377,6 +377,11 @@ def get_tensor_name(
377377
wrapper_idx: int = 0,
378378
):
379379
tensor_name = f"{node.name}_{wrapper_idx}"
380+
381+
# Only append special namings when enable tensor dump, since longer name results bigger .pte
382+
if (handle_id := node.meta.get(QCOM_DEBUG_HANDLE)) and self.enable_tensor_dump:
383+
tensor_name = f"{tensor_name}_debugID_{str(handle_id)}"
384+
380385
# The `input_{id}` is utilized for sorting at runtime. Due to multiple passes in qnn_preprocess,
381386
# the input order between QNN and the original graph’s forward function may differ.
382387
# The `mutbuf_{id}` is utilized for mapping I/O of mutable buffer at runtime.
@@ -397,12 +402,6 @@ def get_tensor_name(
397402
elif is_graph_output(node):
398403
tensor_name = f"output_{tensor_name}"
399404

400-
# Save this for intermediate debugger
401-
# Needs idx since node like topk has 2 outputs
402-
if QCOM_TENSOR_NAME in node.meta:
403-
node.meta[QCOM_TENSOR_NAME][wrapper_idx] = tensor_name
404-
else:
405-
node.meta[QCOM_TENSOR_NAME] = {wrapper_idx: tensor_name}
406405
return tensor_name
407406

408407
def define_custom_tensor_wrapper(
@@ -465,7 +464,6 @@ def define_tensor(
465464

466465
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
467466
return cached
468-
469467
tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
470468
dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
471469
dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)

backends/qualcomm/debugger/README.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ To make the implementation process smooth, we have also provided an example scri
108108
During inference, there might be gaps between QNN and CPU final outputs. This leaves developers unsure about the root cause of accuracy drop. By using this debugger, users can gain better insight into which operation is causing the accuracy drop. Please note that the accuracy drop here refers to comparing QNN with CPU outputs, not the ground truth.
109109

110110
2. Who is this tool for?
111-
This tool is mainly for developers aiming to align QNN with CPU accuracy. Users will be able to identify which layer in the model is causing the accuracy drop, helping them either circumvent the issue by replacing the layer with other operations or contact authors in Qualcomm AI Engine Direct to resolve the accuracy issue. Please refer to the last section under [README.md](../README.md) for authors to contact when encountering any issues.
111+
This tool is mainly for developers aiming to align QNN with CPU accuracy. Users will be able to identify which operation(s) in the model is causing the accuracy drop, helping them either circumvent the issue by:
112+
1. Changing the quant specs for particular operation(s).
113+
2. Replacing the operation(s) with other operation(s)
114+
3. Contact authors in Qualcomm AI Engine Direct to resolve the accuracy issue. Please refer to the last section under [README.md](../README.md) for authors to contact when encountering any issues.
112115

113116

114117
## Design Flow
@@ -119,8 +122,8 @@ flowchart TB;
119122
edge_program --> qnn_lower["QNN with Per-Layer Dump"];
120123
qnn_lower --> qnn_inference[QNN Inference];
121124
qnn_inference --> debug
122-
edge_program --> cpu_lower["Edge CPU with Per-Layer Dump"];
123-
cpu_lower --> cpu_inference["CPU Inference"];
125+
edge_program --> IntermediateOutputCapturer;
126+
IntermediateOutputCapturer --> cpu_inference["CPU Inference"];
124127
cpu_inference --> debug["Debug"];
125128
debug --> output["Output Results"]
126129
```
@@ -172,6 +175,8 @@ adb = SimpleADB(
172175
shared_buffer=args.shared_buffer,
173176
dump_intermediate_outputs=args.dump_intermediate_outputs, # Add this flag
174177
)
178+
adb.push(inputs=single_set_of_input)
179+
adb.execute()
175180
```
176181

177182
### 6: Pull and process the results.
@@ -185,7 +190,7 @@ def validate_intermediate_tensor():
185190
etdump_path=f"{args.artifact}/etdump.etdp",
186191
debug_buffer_path=f"{args.artifact}/debug_output.bin",
187192
)
188-
qnn_intermediate_debugger.intermediate_output_module(*(inputs[0]))
193+
qnn_intermediate_debugger.capture_golden(single_set_of_input)
189194
qnn_intermediate_debugger.generate_results(
190195
title="debug_graph",
191196
path=".",
@@ -244,8 +249,9 @@ To execute the model:
244249
python examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py -b build-android -m ${SOC_MODEL} --device ${SERIAL_NUM} --dataset ${PATH_TO_DATASET} --dump_intermediate_outputs
245250
```
246251

247-
### Limitation
252+
### Limitations
248253
1. The current debugger only supports performing one execution. Multiple executions may cause unknown behavior and are not recommended.
249254
2. Please ignore this if you are using `qnn_executor_runner`. If you have decided to write your own runner, please follow the [tutorial](https://pytorch.org/executorch/stable/etdump.html) on how to implement etdump into your own runner.
250255
3. The current debugger does not support graph with partitions. (WIP)
251256
4. The current debugger does not support LLM models. (WIP)
257+
5. Graph with multimethod. (WIP)

backends/qualcomm/debugger/format_outputs.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import pydot
1212
import torch
1313
from executorch.backends.qualcomm.utils.constants import (
14+
QCOM_DEBUG_HANDLE,
1415
QCOM_QUANT_ATTRS,
1516
QCOM_SCALE,
1617
QCOM_SCALES,
17-
QCOM_TENSOR_NAME,
1818
QCOM_ZERO_POINT,
1919
QCOM_ZERO_POINTS,
2020
)
@@ -46,6 +46,8 @@ def retrieve_node_info(evaluator, node, node_tensor_map):
4646
node_info["op_code"] = node.op
4747
node_info["target"] = typename(node.target)
4848
node_info["num_users"] = len(node.users)
49+
# Only call_function and getitem nodes that is present prior to qnn_preprocess has a debug_handle.
50+
node_info["debug_handle"] = node.meta.get(QCOM_DEBUG_HANDLE, -1)
4951

5052
if "val" in node.meta:
5153
if isinstance(node.meta["val"], torch.Tensor):
@@ -67,24 +69,22 @@ def retrieve_node_info(evaluator, node, node_tensor_map):
6769
if QCOM_ZERO_POINTS in quant_attrs
6870
else quant_attrs.get(QCOM_ZERO_POINT)
6971
)
70-
71-
if node.name in node_tensor_map:
72-
qnn_output, cpu_output, meta = node_tensor_map[node.name]
73-
node_info[QCOM_TENSOR_NAME] = meta.get(QCOM_TENSOR_NAME)
72+
if node_data := node_tensor_map.get(node.name):
73+
qnn_output, cpu_output, debug_meta = node_data
74+
assert debug_meta.edge_node_name == node.name
7475
node_info[evaluator.metric_name()], node_info["is_valid_score"] = (
7576
evaluator.evaluate(qnn_output, cpu_output)
7677
)
7778

7879
# The values in meta are directly retrieved from the node during the forward hook, which means the values should be the same for meta and node.meta.
7980
# Storing these data during the forward hook helps us compare QNN tensors with CPU tensors without traversing the graph.
8081
# We only check "scale" and not "scales" since the forward hook only stores the node's output, which should always be per tensor.
81-
if QCOM_QUANT_ATTRS in node.meta:
82-
assert (
83-
node_info["scale(s)"] == node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE]
82+
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
83+
assert node_info["scale(s)"] == quant_attrs.get(
84+
QCOM_SCALE
8485
), "node meta scale should be same as scale retrieve during forward hook"
85-
assert (
86-
node_info["zero_point(s)"]
87-
== node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT]
86+
assert node_info["zero_point(s)"] == quant_attrs.get(
87+
QCOM_ZERO_POINT
8888
), "node meta zero_point should be same as zero_point retrieve during forward hook"
8989

9090
return node_info
@@ -128,7 +128,7 @@ def get_node_style(is_valid_score: bool):
128128
node_label = "{"
129129
node_label += f"name=%{node_info.get('name')}" + r"\n"
130130
node_label += f"|op_code={node_info.get('op_code')}" + r"\n"
131-
node_label += f"|qnn_tensor_name={node_info.get('qnn_tensor_name')}" + r"\n"
131+
node_label += f"|debug_handle={node_info.get('debug_handle')}" + r"\n"
132132
node_label += f"|target={node_info.get('target')}" + r"\n"
133133
node_label += f"|num_users={node_info.get('num_users')}" + r"\n"
134134
node_label += f"|pytorch_layout={node_info.get('pytorch_layout')}" + r"\n"
@@ -177,14 +177,15 @@ def export_csv(
177177
node_info = retrieve_node_info(
178178
evaluator=evaluator, node=node, node_tensor_map=node_tensor_map
179179
)
180+
180181
node_info_list.append(node_info)
181182

182183
# Writing to a CSV file
183184
with open(f"{path}/{title}.csv", mode="w", newline="") as csv_file:
184185
fieldnames = [
185186
"name",
186187
"op_code",
187-
"qnn_tensor_name",
188+
"debug_handle",
188189
"target",
189190
"num_users",
190191
"pytorch_layout",
@@ -203,19 +204,12 @@ def export_csv(
203204

204205
def export_raw(
205206
path: str,
206-
edge_module: torch.fx.GraphModule,
207207
node_tensor_map: dict,
208208
):
209-
for node in edge_module.graph.nodes:
210-
# These are just unused nodes before fold_quant and still there
211-
if len(node.users) == 0 and node.op == "placeholder":
212-
continue
213-
if paired_event := node_tensor_map.get(node.name):
214-
qnn_output, cpu_output, meta = paired_event
215-
qnn_tensor_name = meta[QCOM_TENSOR_NAME]
216-
qnn_output_path = os.path.join(path, qnn_tensor_name + "_qnn.raw")
217-
cpu_output_path = os.path.join(path, qnn_tensor_name + "_cpu.raw")
218-
qnn_output.numpy().tofile(qnn_output_path)
219-
cpu_output.numpy().tofile(cpu_output_path)
209+
for qnn_output, cpu_output, meta in node_tensor_map.values():
210+
qnn_output_path = os.path.join(path, meta.edge_node_name + "_qnn.raw")
211+
cpu_output_path = os.path.join(path, meta.edge_node_name + "_cpu.raw")
212+
qnn_output.numpy().tofile(qnn_output_path)
213+
cpu_output.numpy().tofile(cpu_output_path)
220214

221215
print(f"Intermediate debugger raw files saved at: {path}")

0 commit comments

Comments
 (0)