Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_inf_values import ReplaceInfValues
from .resolve_debug_handle import ResolveDebugHandle
from .seq_mse import SeqMSE
from .tag_quant_io import TagQuantIO

Expand Down Expand Up @@ -94,6 +95,7 @@
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceInfValues,
ResolveDebugHandle,
SeqMSE,
TagQuantIO,
]
4 changes: 3 additions & 1 deletion backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceInfValues,
ResolveDebugHandle,
TagQuantIO,
)
from executorch.backends.qualcomm._passes.utils import (
Expand Down Expand Up @@ -105,6 +106,7 @@ def get_capture_program_passes():
(Remove0DTensor, True),
(RemoveRedundancy, True),
(TagQuantIO, False),
# ResolveDebugHandle will be added to last, check sorting below
]

passes = OrderedDict()
Expand Down Expand Up @@ -162,7 +164,6 @@ def get_to_edge_transform_passes(
for p in passes_job:
self.add_pass(p)
self.solve_constraints()

sorted_passes = self.passes
self.passes = []
for p in sorted_passes:
Expand All @@ -173,6 +174,7 @@ def get_to_edge_transform_passes(
if "edge_program" in kwargs:
kwargs["edge_program"] = exported_program
self.add_pass(p(**kwargs))
self.add_pass(ResolveDebugHandle())
return self.passes

def transform_for_to_edge_pipeline(
Expand Down
44 changes: 44 additions & 0 deletions backends/qualcomm/_passes/resolve_debug_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DEBUG_HANDLE
from executorch.exir.pass_base import ExportPass, PassResult


class ResolveDebugHandle(ExportPass):
"""
Caution: This pass is executed as the last of the edge_passes.
For any passes executed during qnn_preprocess, users will need to handle debug_handle ID themselves.

Description: During passes transformation, some passes might be copying some node's meta when creating a new node,
which means multiple nodes might be sharing the same debug_handle ID while it shouldn't.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not super understand here: if several nodes comes from one acient node (e..g doing decomposition on some op), they should have the same debug handle for tracing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea is that if we decompose the node but never assign a new handle ID, we are just saving the information for the last decomposed node rather than all decomposed node. I have draw an example below. Since edge and QNN has 1 to 1 mapping in this case, I think it would be better to gather all possible information rather than the last node's debug info. Since we reassign graph_handle, instead of only getting the output of node2, we can also get info for node1.
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here im a little confused: when we see the qnn graph, how can we know that the qnn_node_1 and qnn_node2 here comes from a same super node? Or another q might be, which graph will play as the ground truth graph, when you doing intermediate comparsion?

gather all possible information rather than the last node's debug info.

We won't gather only the last node debug info, but all info.

In ExecuTorch normally we follow this rule:
if we transform {old_node_1, old_node_2, ..., old_node_n} into {new_node_1, new_node_2, ..., new_node_m}, where n and m can be arbitrary number starting from 1, then: eery new_node should have same debug handle, and the debug handle will be set(old_node_1.debug_handle + old_node_2.debug_handle, ..., old_node_n.debug_handle)

you can see if n is 1, this transform will be a operator decomposition; if m is 1, this transform will be a operator fusion, etc.

In this way whenever we see an arbitrary new_node, we will know its ancestor.

Not sure if that make sense to you?

This is critical as Intermediate Debugger uses debug handle as key.
debug_handle ID must be resolved so each op gets its own set of debug_handle ID and intermediate output.
"""

def __init__(self):
super(ResolveDebugHandle, self).__init__()

def call(self, graph_module: torch.fx.GraphModule):
handle_counter = 1
visited = set()
for node in graph_module.graph.nodes:
# Assume node is traversed in topological order, adding a check here to be safe.
if node.target == operator.getitem:
source_node = node.args[0]
assert (
source_node.name in visited
), "Graph is not traversed in topological order, unexpected behavior."
node.meta[QCOM_DEBUG_HANDLE] = source_node.meta[QCOM_DEBUG_HANDLE]
elif node.op == "call_function":
node.meta[QCOM_DEBUG_HANDLE] = handle_counter
handle_counter += 1
visited.add(node.name)

graph_module.recompile()
return PassResult(graph_module, True)
14 changes: 6 additions & 8 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
QCOM_BLOCK_SCALE_OFFSET,
QCOM_BLOCK_SCALES,
QCOM_BLOCK_STORAGE_TYPE,
QCOM_DEBUG_HANDLE,
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_NUM_BLOCKS_PER_AXIS,
Expand All @@ -30,7 +31,6 @@
QCOM_SCALE,
QCOM_SCALE_OFFSET,
QCOM_SCALES,
QCOM_TENSOR_NAME,
QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)
Expand Down Expand Up @@ -377,6 +377,11 @@ def get_tensor_name(
wrapper_idx: int = 0,
):
tensor_name = f"{node.name}_{wrapper_idx}"

# Only append special namings when enable tensor dump, since longer name results bigger .pte
if (handle_id := node.meta.get(QCOM_DEBUG_HANDLE)) and self.enable_tensor_dump:
tensor_name = f"{tensor_name}_debugID_{str(handle_id)}"

# The `input_{id}` is utilized for sorting at runtime. Due to multiple passes in qnn_preprocess,
# the input order between QNN and the original graph’s forward function may differ.
# The `mutbuf_{id}` is utilized for mapping I/O of mutable buffer at runtime.
Expand All @@ -397,12 +402,6 @@ def get_tensor_name(
elif is_graph_output(node):
tensor_name = f"output_{tensor_name}"

# Save this for intermediate debugger
# Needs idx since node like topk has 2 outputs
if QCOM_TENSOR_NAME in node.meta:
node.meta[QCOM_TENSOR_NAME][wrapper_idx] = tensor_name
else:
node.meta[QCOM_TENSOR_NAME] = {wrapper_idx: tensor_name}
return tensor_name

def define_custom_tensor_wrapper(
Expand Down Expand Up @@ -465,7 +464,6 @@ def define_tensor(

if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached

tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)
Expand Down
16 changes: 11 additions & 5 deletions backends/qualcomm/debugger/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ To make the implementation process smooth, we have also provided an example scri
During inference, there might be gaps between QNN and CPU final outputs. This leaves developers unsure about the root cause of accuracy drop. By using this debugger, users can gain better insight into which operation is causing the accuracy drop. Please note that the accuracy drop here refers to comparing QNN with CPU outputs, not the ground truth.

2. Who is this tool for?
This tool is mainly for developers aiming to align QNN with CPU accuracy. Users will be able to identify which layer in the model is causing the accuracy drop, helping them either circumvent the issue by replacing the layer with other operations or contact authors in Qualcomm AI Engine Direct to resolve the accuracy issue. Please refer to the last section under [README.md](../README.md) for authors to contact when encountering any issues.
This tool is mainly for developers aiming to align QNN with CPU accuracy. Users will be able to identify which operation(s) in the model is causing the accuracy drop, helping them either circumvent the issue by:
1. Changing the quant specs for particular operation(s).
2. Replacing the operation(s) with other operation(s)
3. Contact authors in Qualcomm AI Engine Direct to resolve the accuracy issue. Please refer to the last section under [README.md](../README.md) for authors to contact when encountering any issues.


## Design Flow
Expand All @@ -119,8 +122,8 @@ flowchart TB;
edge_program --> qnn_lower["QNN with Per-Layer Dump"];
qnn_lower --> qnn_inference[QNN Inference];
qnn_inference --> debug
edge_program --> cpu_lower["Edge CPU with Per-Layer Dump"];
cpu_lower --> cpu_inference["CPU Inference"];
edge_program --> IntermediateOutputCapturer;
IntermediateOutputCapturer --> cpu_inference["CPU Inference"];
cpu_inference --> debug["Debug"];
debug --> output["Output Results"]
```
Expand Down Expand Up @@ -172,6 +175,8 @@ adb = SimpleADB(
shared_buffer=args.shared_buffer,
dump_intermediate_outputs=args.dump_intermediate_outputs, # Add this flag
)
adb.push(inputs=single_set_of_input)
adb.execute()
```

### 6: Pull and process the results.
Expand All @@ -185,7 +190,7 @@ def validate_intermediate_tensor():
etdump_path=f"{args.artifact}/etdump.etdp",
debug_buffer_path=f"{args.artifact}/debug_output.bin",
)
qnn_intermediate_debugger.intermediate_output_module(*(inputs[0]))
qnn_intermediate_debugger.capture_golden(single_set_of_input)
qnn_intermediate_debugger.generate_results(
title="debug_graph",
path=".",
Expand Down Expand Up @@ -244,8 +249,9 @@ To execute the model:
python examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py -b build-android -m ${SOC_MODEL} --device ${SERIAL_NUM} --dataset ${PATH_TO_DATASET} --dump_intermediate_outputs
```

### Limitation
### Limitations
1. The current debugger only supports performing one execution. Multiple executions may cause unknown behavior and are not recommended.
2. Please ignore this if you are using `qnn_executor_runner`. If you have decided to write your own runner, please follow the [tutorial](https://pytorch.org/executorch/stable/etdump.html) on how to implement etdump into your own runner.
3. The current debugger does not support graph with partitions. (WIP)
4. The current debugger does not support LLM models. (WIP)
5. Graph with multimethod. (WIP)
44 changes: 19 additions & 25 deletions backends/qualcomm/debugger/format_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import pydot
import torch
from executorch.backends.qualcomm.utils.constants import (
QCOM_DEBUG_HANDLE,
QCOM_QUANT_ATTRS,
QCOM_SCALE,
QCOM_SCALES,
QCOM_TENSOR_NAME,
QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)
Expand Down Expand Up @@ -46,6 +46,8 @@ def retrieve_node_info(evaluator, node, node_tensor_map):
node_info["op_code"] = node.op
node_info["target"] = typename(node.target)
node_info["num_users"] = len(node.users)
# Only call_function and getitem nodes that is present prior to qnn_preprocess has a debug_handle.
node_info["debug_handle"] = node.meta.get(QCOM_DEBUG_HANDLE, -1)

if "val" in node.meta:
if isinstance(node.meta["val"], torch.Tensor):
Expand All @@ -67,24 +69,22 @@ def retrieve_node_info(evaluator, node, node_tensor_map):
if QCOM_ZERO_POINTS in quant_attrs
else quant_attrs.get(QCOM_ZERO_POINT)
)

if node.name in node_tensor_map:
qnn_output, cpu_output, meta = node_tensor_map[node.name]
node_info[QCOM_TENSOR_NAME] = meta.get(QCOM_TENSOR_NAME)
if node_data := node_tensor_map.get(node.name):
qnn_output, cpu_output, debug_meta = node_data
assert debug_meta.edge_node_name == node.name
node_info[evaluator.metric_name()], node_info["is_valid_score"] = (
evaluator.evaluate(qnn_output, cpu_output)
)

# The values in meta are directly retrieved from the node during the forward hook, which means the values should be the same for meta and node.meta.
# Storing these data during the forward hook helps us compare QNN tensors with CPU tensors without traversing the graph.
# We only check "scale" and not "scales" since the forward hook only stores the node's output, which should always be per tensor.
if QCOM_QUANT_ATTRS in node.meta:
assert (
node_info["scale(s)"] == node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE]
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
assert node_info["scale(s)"] == quant_attrs.get(
QCOM_SCALE
), "node meta scale should be same as scale retrieve during forward hook"
assert (
node_info["zero_point(s)"]
== node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT]
assert node_info["zero_point(s)"] == quant_attrs.get(
QCOM_ZERO_POINT
), "node meta zero_point should be same as zero_point retrieve during forward hook"

return node_info
Expand Down Expand Up @@ -128,7 +128,7 @@ def get_node_style(is_valid_score: bool):
node_label = "{"
node_label += f"name=%{node_info.get('name')}" + r"\n"
node_label += f"|op_code={node_info.get('op_code')}" + r"\n"
node_label += f"|qnn_tensor_name={node_info.get('qnn_tensor_name')}" + r"\n"
node_label += f"|debug_handle={node_info.get('debug_handle')}" + r"\n"
node_label += f"|target={node_info.get('target')}" + r"\n"
node_label += f"|num_users={node_info.get('num_users')}" + r"\n"
node_label += f"|pytorch_layout={node_info.get('pytorch_layout')}" + r"\n"
Expand Down Expand Up @@ -177,14 +177,15 @@ def export_csv(
node_info = retrieve_node_info(
evaluator=evaluator, node=node, node_tensor_map=node_tensor_map
)

node_info_list.append(node_info)

# Writing to a CSV file
with open(f"{path}/{title}.csv", mode="w", newline="") as csv_file:
fieldnames = [
"name",
"op_code",
"qnn_tensor_name",
"debug_handle",
"target",
"num_users",
"pytorch_layout",
Expand All @@ -203,19 +204,12 @@ def export_csv(

def export_raw(
path: str,
edge_module: torch.fx.GraphModule,
node_tensor_map: dict,
):
for node in edge_module.graph.nodes:
# These are just unused nodes before fold_quant and still there
if len(node.users) == 0 and node.op == "placeholder":
continue
if paired_event := node_tensor_map.get(node.name):
qnn_output, cpu_output, meta = paired_event
qnn_tensor_name = meta[QCOM_TENSOR_NAME]
qnn_output_path = os.path.join(path, qnn_tensor_name + "_qnn.raw")
cpu_output_path = os.path.join(path, qnn_tensor_name + "_cpu.raw")
qnn_output.numpy().tofile(qnn_output_path)
cpu_output.numpy().tofile(cpu_output_path)
for qnn_output, cpu_output, meta in node_tensor_map.values():
qnn_output_path = os.path.join(path, meta.edge_node_name + "_qnn.raw")
cpu_output_path = os.path.join(path, meta.edge_node_name + "_cpu.raw")
qnn_output.numpy().tofile(qnn_output_path)
cpu_output.numpy().tofile(cpu_output_path)

print(f"Intermediate debugger raw files saved at: {path}")
Loading
Loading