diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 49449fe2190..e47938ef2a1 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -47,6 +47,7 @@
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
 from .replace_inf_values import ReplaceInfValues
+from .resolve_debug_handle import ResolveDebugHandle
 from .seq_mse import SeqMSE
 from .tag_quant_io import TagQuantIO
 
@@ -94,6 +95,7 @@
     RemoveRedundancy,
     ReplaceArangeArgs,
     ReplaceInfValues,
+    ResolveDebugHandle,
     SeqMSE,
     TagQuantIO,
 ]
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 46a1dfb0970..96394259397 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -52,6 +52,7 @@
     RemoveRedundancy,
     ReplaceArangeArgs,
     ReplaceInfValues,
+    ResolveDebugHandle,
     TagQuantIO,
 )
 from executorch.backends.qualcomm._passes.utils import (
@@ -105,6 +106,7 @@ def get_capture_program_passes():
         (Remove0DTensor, True),
         (RemoveRedundancy, True),
         (TagQuantIO, False),
+        # ResolveDebugHandle is added last; see the sorting in get_to_edge_transform_passes below
     ]
 
     passes = OrderedDict()
@@ -162,7 +164,6 @@ def get_to_edge_transform_passes(
         for p in passes_job:
             self.add_pass(p)
         self.solve_constraints()
-
         sorted_passes = self.passes
         self.passes = []
         for p in sorted_passes:
@@ -173,6 +174,7 @@ def get_to_edge_transform_passes(
             if "edge_program" in kwargs:
                 kwargs["edge_program"] = exported_program
             self.add_pass(p(**kwargs))
+        self.add_pass(ResolveDebugHandle())
         return self.passes
 
     def transform_for_to_edge_pipeline(
diff --git a/backends/qualcomm/_passes/resolve_debug_handle.py b/backends/qualcomm/_passes/resolve_debug_handle.py
new file mode 100644
index 00000000000..1168d553a0a
--- /dev/null
+++ b/backends/qualcomm/_passes/resolve_debug_handle.py
@@ -0,0 +1,44 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import operator
+
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_DEBUG_HANDLE
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ResolveDebugHandle(ExportPass):
+    """
+    Caution: This pass is executed as the last of the edge transform passes.
+    For any passes executed during qnn_preprocess, users will need to handle the debug_handle IDs themselves.
+
+    Description: During pass transformations, some passes copy a node's meta when creating a new node,
+    which means multiple nodes may end up sharing the same debug_handle ID when they shouldn't.
+    This is critical because the Intermediate Debugger uses the debug handle as its key.
+    The debug_handle IDs must be resolved so that each op gets its own debug_handle ID and intermediate output.
+    """
+
+    def __init__(self):
+        super(ResolveDebugHandle, self).__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        handle_counter = 1
+        visited = set()
+        for node in graph_module.graph.nodes:
+            # Nodes are assumed to be traversed in topological order; add a check here to be safe.
+            if node.target == operator.getitem:
+                source_node = node.args[0]
+                assert (
+                    source_node.name in visited
+                ), "Graph is not traversed in topological order, unexpected behavior."
+ node.meta[QCOM_DEBUG_HANDLE] = source_node.meta[QCOM_DEBUG_HANDLE] + elif node.op == "call_function": + node.meta[QCOM_DEBUG_HANDLE] = handle_counter + handle_counter += 1 + visited.add(node.name) + + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index f2fcf65c896..1697f9288bf 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -19,6 +19,7 @@ QCOM_BLOCK_SCALE_OFFSET, QCOM_BLOCK_SCALES, QCOM_BLOCK_STORAGE_TYPE, + QCOM_DEBUG_HANDLE, QCOM_DTYPE, QCOM_ENCODING, QCOM_NUM_BLOCKS_PER_AXIS, @@ -30,7 +31,6 @@ QCOM_SCALE, QCOM_SCALE_OFFSET, QCOM_SCALES, - QCOM_TENSOR_NAME, QCOM_ZERO_POINT, QCOM_ZERO_POINTS, ) @@ -377,6 +377,11 @@ def get_tensor_name( wrapper_idx: int = 0, ): tensor_name = f"{node.name}_{wrapper_idx}" + + # Only append special namings when enable tensor dump, since longer name results bigger .pte + if (handle_id := node.meta.get(QCOM_DEBUG_HANDLE)) and self.enable_tensor_dump: + tensor_name = f"{tensor_name}_debugID_{str(handle_id)}" + # The `input_{id}` is utilized for sorting at runtime. Due to multiple passes in qnn_preprocess, # the input order between QNN and the original graph’s forward function may differ. # The `mutbuf_{id}` is utilized for mapping I/O of mutable buffer at runtime. @@ -397,12 +402,6 @@ def get_tensor_name( elif is_graph_output(node): tensor_name = f"output_{tensor_name}" - # Save this for intermediate debugger - # Needs idx since node like topk has 2 outputs - if QCOM_TENSOR_NAME in node.meta: - node.meta[QCOM_TENSOR_NAME][wrapper_idx] = tensor_name - else: - node.meta[QCOM_TENSOR_NAME] = {wrapper_idx: tensor_name} return tensor_name def define_custom_tensor_wrapper( @@ -465,7 +464,6 @@ def define_tensor( if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): return cached - tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx) dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size() dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims) diff --git a/backends/qualcomm/debugger/README.md b/backends/qualcomm/debugger/README.md index 1c91382131f..1aafd0d1e8a 100644 --- a/backends/qualcomm/debugger/README.md +++ b/backends/qualcomm/debugger/README.md @@ -108,7 +108,10 @@ To make the implementation process smooth, we have also provided an example scri During inference, there might be gaps between QNN and CPU final outputs. This leaves developers unsure about the root cause of accuracy drop. By using this debugger, users can gain better insight into which operation is causing the accuracy drop. Please note that the accuracy drop here refers to comparing QNN with CPU outputs, not the ground truth. 2. Who is this tool for? - This tool is mainly for developers aiming to align QNN with CPU accuracy. Users will be able to identify which layer in the model is causing the accuracy drop, helping them either circumvent the issue by replacing the layer with other operations or contact authors in Qualcomm AI Engine Direct to resolve the accuracy issue. Please refer to the last section under [README.md](../README.md) for authors to contact when encountering any issues. + This tool is mainly for developers aiming to align QNN with CPU accuracy. Users will be able to identify which operation(s) in the model is causing the accuracy drop, helping them either circumvent the issue by: + 1. Changing the quant specs for particular operation(s). + 2. 
Replacing the operation(s) with other operation(s).
+    3. Contacting the authors of Qualcomm AI Engine Direct to resolve the accuracy issue. Please refer to the last section under [README.md](../README.md) for the authors to contact when encountering any issues.
 
 ## Design Flow
 
@@ -119,8 +122,8 @@ flowchart TB;
     edge_program --> qnn_lower["QNN with Per-Layer Dump"];
     qnn_lower --> qnn_inference[QNN Inference];
     qnn_inference --> debug
-    edge_program --> cpu_lower["Edge CPU with Per-Layer Dump"];
-    cpu_lower --> cpu_inference["CPU Inference"];
+    edge_program --> IntermediateOutputCapturer;
+    IntermediateOutputCapturer --> cpu_inference["CPU Inference"];
     cpu_inference --> debug["Debug"];
     debug --> output["Output Results"]
 ```
@@ -172,6 +175,8 @@ adb = SimpleADB(
     shared_buffer=args.shared_buffer,
     dump_intermediate_outputs=args.dump_intermediate_outputs,  # Add this flag
 )
+adb.push(inputs=single_set_of_input)
+adb.execute()
 ```
 
 ### 6: Pull and process the results.
@@ -185,7 +190,7 @@ def validate_intermediate_tensor():
         etdump_path=f"{args.artifact}/etdump.etdp",
         debug_buffer_path=f"{args.artifact}/debug_output.bin",
     )
-    qnn_intermediate_debugger.intermediate_output_module(*(inputs[0]))
+    qnn_intermediate_debugger.capture_golden(single_set_of_input)
     qnn_intermediate_debugger.generate_results(
         title="debug_graph",
         path=".",
@@ -244,8 +249,9 @@ To execute the model:
 python examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py -b build-android -m ${SOC_MODEL} --device ${SERIAL_NUM} --dataset ${PATH_TO_DATASET} --dump_intermediate_outputs
 ```
 
-### Limitation
+### Limitations
 1. The current debugger only supports performing one execution. Multiple executions may cause unknown behavior and are not recommended.
 2. Please ignore this if you are using `qnn_executor_runner`. If you have decided to write your own runner, please follow the [tutorial](https://pytorch.org/executorch/stable/etdump.html) on how to implement etdump into your own runner.
 3. The current debugger does not support graph with partitions. (WIP)
 4. The current debugger does not support LLM models. (WIP)
+5. The current debugger does not support graphs with multiple methods. (WIP)
diff --git a/backends/qualcomm/debugger/format_outputs.py b/backends/qualcomm/debugger/format_outputs.py
index 05f5c908919..f210634210d 100644
--- a/backends/qualcomm/debugger/format_outputs.py
+++ b/backends/qualcomm/debugger/format_outputs.py
@@ -11,10 +11,10 @@
 import pydot
 import torch
 from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DEBUG_HANDLE,
     QCOM_QUANT_ATTRS,
     QCOM_SCALE,
     QCOM_SCALES,
-    QCOM_TENSOR_NAME,
     QCOM_ZERO_POINT,
     QCOM_ZERO_POINTS,
 )
@@ -46,6 +46,8 @@ def retrieve_node_info(evaluator, node, node_tensor_map):
     node_info["op_code"] = node.op
     node_info["target"] = typename(node.target)
     node_info["num_users"] = len(node.users)
+    # Only call_function and getitem nodes that are present prior to qnn_preprocess have a debug_handle.
+ node_info["debug_handle"] = node.meta.get(QCOM_DEBUG_HANDLE, -1) if "val" in node.meta: if isinstance(node.meta["val"], torch.Tensor): @@ -67,10 +69,9 @@ def retrieve_node_info(evaluator, node, node_tensor_map): if QCOM_ZERO_POINTS in quant_attrs else quant_attrs.get(QCOM_ZERO_POINT) ) - - if node.name in node_tensor_map: - qnn_output, cpu_output, meta = node_tensor_map[node.name] - node_info[QCOM_TENSOR_NAME] = meta.get(QCOM_TENSOR_NAME) + if node_data := node_tensor_map.get(node.name): + qnn_output, cpu_output, debug_meta = node_data + assert debug_meta.edge_node_name == node.name node_info[evaluator.metric_name()], node_info["is_valid_score"] = ( evaluator.evaluate(qnn_output, cpu_output) ) @@ -78,13 +79,12 @@ def retrieve_node_info(evaluator, node, node_tensor_map): # The values in meta are directly retrieved from the node during the forward hook, which means the values should be the same for meta and node.meta. # Storing these data during the forward hook helps us compare QNN tensors with CPU tensors without traversing the graph. # We only check "scale" and not "scales" since the forward hook only stores the node's output, which should always be per tensor. - if QCOM_QUANT_ATTRS in node.meta: - assert ( - node_info["scale(s)"] == node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE] + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + assert node_info["scale(s)"] == quant_attrs.get( + QCOM_SCALE ), "node meta scale should be same as scale retrieve during forward hook" - assert ( - node_info["zero_point(s)"] - == node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] + assert node_info["zero_point(s)"] == quant_attrs.get( + QCOM_ZERO_POINT ), "node meta zero_point should be same as zero_point retrieve during forward hook" return node_info @@ -128,7 +128,7 @@ def get_node_style(is_valid_score: bool): node_label = "{" node_label += f"name=%{node_info.get('name')}" + r"\n" node_label += f"|op_code={node_info.get('op_code')}" + r"\n" - node_label += f"|qnn_tensor_name={node_info.get('qnn_tensor_name')}" + r"\n" + node_label += f"|debug_handle={node_info.get('debug_handle')}" + r"\n" node_label += f"|target={node_info.get('target')}" + r"\n" node_label += f"|num_users={node_info.get('num_users')}" + r"\n" node_label += f"|pytorch_layout={node_info.get('pytorch_layout')}" + r"\n" @@ -177,6 +177,7 @@ def export_csv( node_info = retrieve_node_info( evaluator=evaluator, node=node, node_tensor_map=node_tensor_map ) + node_info_list.append(node_info) # Writing to a CSV file @@ -184,7 +185,7 @@ def export_csv( fieldnames = [ "name", "op_code", - "qnn_tensor_name", + "debug_handle", "target", "num_users", "pytorch_layout", @@ -203,19 +204,12 @@ def export_csv( def export_raw( path: str, - edge_module: torch.fx.GraphModule, node_tensor_map: dict, ): - for node in edge_module.graph.nodes: - # These are just unused nodes before fold_quant and still there - if len(node.users) == 0 and node.op == "placeholder": - continue - if paired_event := node_tensor_map.get(node.name): - qnn_output, cpu_output, meta = paired_event - qnn_tensor_name = meta[QCOM_TENSOR_NAME] - qnn_output_path = os.path.join(path, qnn_tensor_name + "_qnn.raw") - cpu_output_path = os.path.join(path, qnn_tensor_name + "_cpu.raw") - qnn_output.numpy().tofile(qnn_output_path) - cpu_output.numpy().tofile(cpu_output_path) + for qnn_output, cpu_output, meta in node_tensor_map.values(): + qnn_output_path = os.path.join(path, meta.edge_node_name + "_qnn.raw") + cpu_output_path = os.path.join(path, meta.edge_node_name + "_cpu.raw") + 
qnn_output.numpy().tofile(qnn_output_path) + cpu_output.numpy().tofile(cpu_output_path) print(f"Intermediate debugger raw files saved at: {path}") diff --git a/backends/qualcomm/debugger/qnn_intermediate_debugger.py b/backends/qualcomm/debugger/qnn_intermediate_debugger.py index 904dd4f6ccb..9189fc86db4 100644 --- a/backends/qualcomm/debugger/qnn_intermediate_debugger.py +++ b/backends/qualcomm/debugger/qnn_intermediate_debugger.py @@ -4,23 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import operator import os import warnings +from dataclasses import dataclass from enum import IntEnum +from typing import Tuple import torch - from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform from executorch.backends.qualcomm.utils.constants import ( QCOM_AXIS_ORDER, + QCOM_DEBUG_HANDLE, QCOM_QUANT_ATTRS, QCOM_SCALE, - QCOM_TENSOR_NAME, QCOM_ZERO_POINT, ) from executorch.devtools import Inspector +from executorch.devtools.inspector._intermediate_output_capturer import ( + IntermediateOutputCapturer, +) from executorch.exir.sym_util import eval_shape from .format_outputs import export_csv, export_raw, export_svg @@ -33,46 +36,61 @@ class OutputFormat(IntEnum): DUMP_RAW = 2 -class IntermediateModule(torch.nn.Module): - """ - This class serves as an intermediate point and is inserted right after the call_function node. - It also saves some metadata such as scale, offset, etc. - Since we just want to check the intermediate output, we will directly return the value during the forward call. +class QNNIntermediateDebugger: + """This is a debugger tool to leverage IntermediateOutputCapturer to dump CPU intermediate results + and compare it with QNN's intermediate output to identify any QNN accuracy issues. """ - def __init__( - self, - module_name: str, - qnn_tensor_name: str, - node_name: str, - scale: float, - zero_point: int, - revert_order: bool = None, - ): - super().__init__() - self.module_name = module_name - self.qnn_tensor_name = qnn_tensor_name - self.node_name = node_name - self.scale = scale - self.zero_point = zero_point - self.revert_order = revert_order + @dataclass(frozen=True) + class QnnDebugMetaData: + """ + Summary: Meta data that will be used later for CPU V.S. QNN comparison + handle_id: handle_id under node.meta. + scale: Scale of the node for quant model, None for FP model. + zero_point: Zero point of the node for quant model, None for FP model. + is_qcom_layout: Taking 4D tensor as example, whether the node passed to QNN is in layout of NCHW (pytorch layout) or NHWC(qcom layout). + This is directly related is QCOM_AXIS_ORDER meta added during LayoutTransform pass. + edge_node_name: The node.name during edge ir. + """ - def forward(self, x): - return x + handle_id: int + scale: float + zero_point: int + is_qcom_layout: bool + edge_node_name: str + def __init__(self, keep_qnn_layout: bool = False): + """ -class QNNIntermediateDebugger: - """This is a debugger tool capable of retrieving intermediate results for CPU edge EP. - We can further compare these with QNN's intermediate output to identify any QNN accuracy issues. - """ + Args: + keep_qnn_layout (bool, optional): For general usage, keep this as False. + When comparing CPU result with QNN result, taking 4D tensors as example, + some QNN output is in NHWC format, while CPU is in NCHW format. + If turned to true, debugger will compare QNN using NHWC format. 
+ QNN quantized outputs will also remain quantized instead of FP. + Please notice this will cause significant mismatch between QNN and CPU. + This feature is enabled for internal usage when dumping RAW files. - def __init__(self): - self.intermediate_outputs = {} + """ + self.keep_qnn_layout = keep_qnn_layout + self.golden_intermediate_outputs = None + self.node_tensor_map = None - def set_edge_module(self, edge_module: torch.fx.graph_module.GraphModule): - self.orig_edge = copy.deepcopy(edge_module) - self.intermediate_output_module = self._insert_intermediate_module( - copy.deepcopy(edge_module) + if self.keep_qnn_layout: + warnings.warn( + "[QNN Delegate Debugger]: keep_qnn_layout is not recommended for general use case. " + "QNN and CPU has different dtype(FP V.S. Quantized) and data formats(NCHW V.S. NHWC) in a lot of cases.", + stacklevel=1, + ) + + def capture_golden(self, sample_input: Tuple[torch.Tensor]): + if self.golden_intermediate_outputs: + warnings.warn( + "[QNN Delegate Debugger]: Golden is already captured. Override the previous result. Please ensure this is intentional.", + stacklevel=2, + ) + self.golden_intermediate_outputs = ( + self.intermediate_golden_capturer.run_and_capture(sample_input) ) def generate_results( @@ -82,43 +100,39 @@ def generate_results( output_format: OutputFormat, inspector: Inspector, evaluator: MetricEvaluatorBase = None, - keep_qnn_layout: bool = False, ): assert isinstance( output_format, OutputFormat - ), "output_format passed in is not an instance of OutputFormat" + ), "[QNN Delegate Debugger]: output_format passed in is not an instance of OutputFormat" os.makedirs(path, exist_ok=True) - if keep_qnn_layout: - warnings.warn( - "[QNN Delegate Debugger]: keep_qnn_layout is not recommended for general use case. " - "QNN and CPU has different dtype(FP V.S. Quantized) and data formats(NCHW V.S. NHWC) in a lot of cases.", - stacklevel=1, - ) - # Due to users can switch between keep_qnn_layout between generate_results, rematch this every time. - # Make this a class variable if repeat matching is taking too long and handle keep_qnn_layout. - node_tensor_map = self._match_tensors( - inspector=inspector, - keep_qnn_layout=keep_qnn_layout, - ) + # When use calls this function multiple times, only match tensor during 1st time. + if self.node_tensor_map is None: + self.node_tensor_map = self._match_tensors( + inspector=inspector, + ) if output_format == OutputFormat.SVG_GRAPHS: - assert evaluator is not None, "Please provide an evaluator." + assert ( + evaluator is not None + ), "[QNN Delegate Debugger]: Please provide an evaluator." export_svg( title=title, path=path, evaluator=evaluator, - edge_module=self.orig_edge, - node_tensor_map=node_tensor_map, + edge_module=self.edge_module, + node_tensor_map=self.node_tensor_map, ) elif output_format == OutputFormat.CSV_FILES: - assert evaluator is not None, "Please provide an evaluator." + assert ( + evaluator is not None + ), "[QNN Delegate Debugger]: Please provide an evaluator." 
export_csv( title=title, path=path, evaluator=evaluator, - edge_module=self.orig_edge, - node_tensor_map=node_tensor_map, + edge_module=self.edge_module, + node_tensor_map=self.node_tensor_map, ) elif output_format == OutputFormat.DUMP_RAW: warnings.warn( @@ -132,8 +146,7 @@ def generate_results( ) export_raw( path=path, - edge_module=self.intermediate_output_module, - node_tensor_map=node_tensor_map, + node_tensor_map=self.node_tensor_map, ) else: warnings.warn( @@ -142,108 +155,11 @@ def generate_results( ) return - def _insert_intermediate_module( # noqa: C901 - self, edge_module: torch.fx.graph_module.GraphModule - ): - """ - This feature is for intermediate tensor dump on the host CPU. - After we get an edge GraphModule, we insert submodule between each call_function node, - and we register forward hooks to store the intermediate results. - We have to use the edge GraphModule because this is the graph closest to what QNN is executing - while still being a valid graph to ExecuTorch. - - Args: - edge_module (exir.ExirExportedProgram): A deep copy of edge ir graph module. - We need to deep copy so we don't mess up the original edge_ep. - Returns: - exir.ExirExportedProgram: A deep copy of edge graph_module with intermediate modules inserted. + def _process_qnn_output( + self, qnn_output: torch.tensor, meta: QnnDebugMetaData + ) -> torch.tensor: """ - - def hook_fn(module, input, output): - meta = {} - meta[QCOM_TENSOR_NAME] = module.qnn_tensor_name - meta["node_name"] = module.node_name - meta[QCOM_SCALE] = module.scale - meta[QCOM_ZERO_POINT] = module.zero_point - meta["revert_order"] = module.revert_order - meta["output"] = output # CPU output - - assert ( - module.qnn_tensor_name not in self.intermediate_outputs - ), f"{module.qnn_tensor_name} checked already, check if this is a potential error" - self.intermediate_outputs[module.qnn_tensor_name] = meta - - graph = edge_module.graph - module_count = 0 - for node in graph.nodes: - if node.op == "call_function": - module_name = f"intermediate_module_{module_count}" - module_count += 1 - with graph.inserting_after(node): - scale = None - zero_point = None - if QCOM_QUANT_ATTRS in node.meta: - scale = node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE] - zero_point = node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] - - revert_order = QCOM_AXIS_ORDER in node.meta - - if node.target == operator.getitem: - index = node.args[1] - # Ex: topk -> intermediate_module -> get_item - src_node = node.args[0].args[0] - qnn_tensor_name = src_node.meta[QCOM_TENSOR_NAME][index] - elif any(user.target == operator.getitem for user in node.users): - # For cases like topK, qnn_tensor_name is stored in get_item instead of source_node itself. - assert all( - user.target == operator.getitem for user in node.users - ), "Expect all users to be get_item node" - qnn_tensor_name = node.name - elif QCOM_TENSOR_NAME in node.meta: - assert ( - len(node.meta[QCOM_TENSOR_NAME]) == 1 - ), "Expecting a single qnn_tensor name but get more than 1." 
- qnn_tensor_name = node.meta[QCOM_TENSOR_NAME][0] - else: - # Unused - qnn_tensor_name = node.name - - obs = IntermediateModule( - module_name=module_name, - qnn_tensor_name=qnn_tensor_name, - node_name=node.name, - scale=scale, - zero_point=zero_point, - revert_order=revert_order, - ) - setattr( - edge_module, - module_name, - obs, - ) - new_obs = graph.create_node("call_module", module_name, (node,), {}) - orig_users = list(node.users.keys()) - for user_node in orig_users: - if user_node is new_obs: - continue - user_node.replace_input_with(node, new_obs) - - # Register hooks for all intermediate layers - for ( - _, - layer, - ) in edge_module.named_modules(): - if isinstance(layer, IntermediateModule): - layer.register_forward_hook(hook_fn) - - graph.eliminate_dead_code() - edge_module.recompile() - - return edge_module - - def _process_qnn_output(self, qnn_output: torch.tensor, meta: dict) -> torch.tensor: - """ - QNN intermediate results are all quantized. + QNN intermediate results could be quantized. We need to dequantize them to match CPU float values. Additionally, we need to revert the layout format for layout-sensitive nodes. @@ -255,62 +171,83 @@ def _process_qnn_output(self, qnn_output: torch.tensor, meta: dict) -> torch.ten torch.tensor: Processed tensor that should have same dtype and shape as CPU tensors. """ qnn_output = qnn_output.to(torch.float32) - if meta[QCOM_SCALE] is not None: - scale = meta[QCOM_SCALE] - zero_point = meta[QCOM_ZERO_POINT] + if meta.scale is not None: + scale = meta.scale + zero_point = meta.zero_point qnn_output = ( qnn_output.sub(zero_point).mul(scale).to(torch.float32).contiguous() ) - if meta["revert_order"]: + if meta.is_qcom_layout: axis_order = LayoutTransform.get_axis_order( eval_shape(qnn_output.shape), reverse=True ) qnn_output = qnn_output.permute(axis_order) return qnn_output - def _match_tensors(self, inspector: Inspector, keep_qnn_layout: bool = False): + def _match_tensors(self, inspector: Inspector): """ Map QNN tensors back to CPU tensors. - Create a map using the node name as the key and (preprocessed/postprocessed QNN tensor, CPU tensor, meta) as the value. + Create a map using the edge_node_name as the key and (preprocessed/postprocessed QNN tensor, CPU tensor, QnnDebugMetaData) as the value. We need meta because it holds values such as scale, offset, layout sensitivity, etc. Args: inspector (Inspector): Inspector that parse QNN runtime intermediate outputs - keep_qnn_layout (bool): If true, store QNN outputs in NHWC format. Not recommended for general users. Returns: - A dict storing {node_name : tuple(qnn_output, cpu_output, meta_info)} - Meta_info is the info stored during forward hook_fn. 
+            A dict storing {edge_node_name: tuple(qnn_output, cpu_output, QnnDebugMetaData)}
         """
-        # node_tensor_map {key: tuple(qnn_output, cpu_output, meta_info)}
+        # node_tensor_map {edge_node_name: tuple(qnn_output, cpu_output, QnnDebugMetaData)}
         node_tensor_map = {}
         # OPs that only exists in QNN but not CPU Golden
         unmatched_qnn_tensors = []
         # E.g.: DELEGATE_CALL (This is the model input data), 'Method::execute'
         ignored_events = []
-        # Collected with forward hook
-        intermediate_outputs = self.intermediate_outputs
+
         for event_block in inspector.event_blocks:
             if event_block.name == "Execute":
                 for event in event_block.events:
                     # If user enables profiling and dump intermediate outputs the same time, we need to skip the profiling event
                     if event.perf_data is not None and event.is_delegated_op:
                         continue
-                    if meta := intermediate_outputs.get(event.name):
-                        node_name = meta["node_name"]
-                        cpu_output = meta["output"]
-                        qnn_output = (
-                            event.debug_data[0]
-                            if keep_qnn_layout
-                            else self._process_qnn_output(event.debug_data[0], meta)
-                        )
-                        node_tensor_map[node_name] = (
-                            qnn_output,
-                            cpu_output,
-                            meta,
-                        )
+                    if (
+                        event.name.isdigit()
+                        and (int(event.name),) in self.golden_intermediate_outputs
+                    ):
+
+                        debug_handle = (int(event.name),)
+                        cpu_output = self.golden_intermediate_outputs[debug_handle]
+                        if torch.is_tensor(cpu_output):
+                            cpu_output = [cpu_output]
+
+                        node_meta = self.node_meta_map[debug_handle]
+
+                        # We can't do assertions here because of some edge cases.
+                        # Ex: max_pool2d has 2 outputs. However, QNN only has 1 output and the graph only uses output[0].
+                        # The CPU generates an extra output that's never used.
+                        if len(cpu_output) != len(event.debug_data):
+                            warnings.warn(
+                                f"[QNN Delegate Debugger]: Number of outputs does not match. "
+                                f"CPU has {len(cpu_output)} outputs. QNN has {len(event.debug_data)} outputs, possibly due to the OP generating multiple outputs with some unused. "
+                                f"Check the following node_meta info to see if this is desired: {node_meta}",
+                                stacklevel=1,
+                            )
+                        for i, event_data in enumerate(event.debug_data):
+                            qnn_output = (
+                                event_data
+                                if self.keep_qnn_layout
+                                else self._process_qnn_output(event_data, node_meta[i])
+                            )
+                            edge_node_name = node_meta[i].edge_node_name
+                            assert (
+                                edge_node_name not in node_tensor_map
+                            ), f"[QNN Delegate Debugger]: Duplicate tensor name found when visiting {edge_node_name}"
+                            node_tensor_map[edge_node_name] = (
+                                qnn_output,
+                                cpu_output[i],
+                                node_meta[i],
+                            )
                     else:
                         (
                             unmatched_qnn_tensors.append(event.name)
@@ -319,10 +256,58 @@ def _match_tensors(self, inspector: Inspector, keep_qnn_layout: bool = False):
                         )
 
         warnings.warn(
-            f"The following events are ignored: {ignored_events}", stacklevel=1
+            f"[QNN Delegate Debugger]: The following events are ignored: {ignored_events}",
+            stacklevel=1,
         )
         warnings.warn(
-            f"The following QNN OPs are missing CPU reference. OPs added during qnn_preprocess will not have CPU reference. Please ensure the operations below are created during qnn_preprocess. {unmatched_qnn_tensors}",
+            f"[QNN Delegate Debugger]: The following QNN OPs are missing a CPU reference. OPs added during qnn_preprocess will not have a CPU reference. Please ensure the operations below were created during qnn_preprocess. 
{unmatched_qnn_tensors}", stacklevel=1, ) return node_tensor_map + + def _set_edge_module( + self, edge_module: torch.fx.graph_module.GraphModule, debug_handle_map: dict + ): + self.edge_module = edge_module + self.intermediate_golden_capturer = IntermediateOutputCapturer( + module=self.edge_module + ) + self.debug_handle_map = debug_handle_map + self.node_meta_map = {} + for node in self.edge_module.graph.nodes: + + # For multi output ops like topk, + # meta info is stored in getitem, so skip source node itself. + if any(user.target == operator.getitem for user in node.users): + # Assume if a node user is getitem, all users are getitem + assert all( + user.target == operator.getitem for user in node.users + ), "[QNN Delegate Debugger]: Expect all users to be get_item node" + continue + + if handle_id := node.meta.get(QCOM_DEBUG_HANDLE): + scale = None + zero_point = None + is_qcom_layout = QCOM_AXIS_ORDER in node.meta + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + scale = quant_attrs[QCOM_SCALE] + zero_point = quant_attrs[QCOM_ZERO_POINT] + + debug_meta = QNNIntermediateDebugger.QnnDebugMetaData( + handle_id=handle_id, + scale=scale, + zero_point=zero_point, + is_qcom_layout=is_qcom_layout, + edge_node_name=node.name, + ) + if node.target == operator.getitem: + output_idx = node.args[1] + if (handle_id,) in self.node_meta_map: + self.node_meta_map[(handle_id,)][output_idx] = debug_meta + else: + self.node_meta_map[(handle_id,)] = {output_idx: debug_meta} + else: + assert ( + handle_id, + ) not in self.node_meta_map, f"[QNN Delegate Debugger]: Duplicate handle_id {handle_id} found when visiting {node.name}." + self.node_meta_map[(handle_id,)] = {0: debug_meta} diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index c0351b01ed6..223d9fe540e 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -19,7 +19,10 @@ from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( flatbuffer_to_option, ) -from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER +from executorch.backends.qualcomm.utils.constants import ( + QCOM_AXIS_ORDER, + QCOM_DEBUG_HANDLE, +) from executorch.backends.qualcomm.utils.qnn_manager_lifecycle import ( get_current_qnn_manager, ) @@ -28,6 +31,7 @@ CompileSpec, PreprocessResult, ) +from executorch.exir.backend.utils import DelegateMappingBuilder from torch.export.exported_program import ExportedProgram DEFAULT_DEBUG_HANDLE = 65535 @@ -138,7 +142,7 @@ def preprocess( ) @staticmethod - def preprocess_multimethod( + def preprocess_multimethod( # noqa: C901 edge_programs: Dict[str, List[ExportedProgram]], compile_specs: Dict[str, List[List[CompileSpec]]], ) -> PreprocessResult: @@ -161,8 +165,9 @@ def preprocess_multimethod( qnn_manager = get_current_qnn_manager( option.backend_options.backend_type, compile_spec ) + debug_handle_builder = DelegateMappingBuilder(generated_identifiers=False) for i in range(num_sub_graphs): - # e.g. 2 methods (x, y) with 3 partitions + # e.g. 
2 methods (x, y) with 3 subgraphs(partitions) # > context_binary_0: [x.subgraph_0, y.subgraph_0] # > context_binary_1: [x.subgraph_1, y.subgraph_1] # > context_binary_2: [x.subgraph_2, y.subgraph_2] @@ -176,6 +181,13 @@ def preprocess_multimethod( option.op_package_options.op_package_infos, option.use_mha2sha, ) + if qnn_manager.IsTensorDump(): + for node in programs[i].graph.nodes: + if handle_id := node.meta.get(QCOM_DEBUG_HANDLE): + debug_handle_builder.insert_delegate_mapping_entry( + handles=handle_id, + identifier=node.name, + ) if isinstance(py_op_wrappers, bytes): ctx_binary_list.append(py_op_wrappers) else: @@ -185,7 +197,6 @@ def preprocess_multimethod( for py_op_wrapper in py_op_wrappers ] ) - if len(py_op_wrapper_list) == len(edge_programs.values()): qnn_context_binary = qnn_manager.Compile( graph_names, py_op_wrapper_list @@ -204,15 +215,18 @@ def preprocess_multimethod( all_processed_results[key].append( PreprocessResult( processed_bytes=bytes(qnn_context_binary), - debug_handle_map={}, + debug_handle_map=debug_handle_builder.get_delegate_mapping(), ) ) + elif len(ctx_binary_list) == len(edge_programs.values()): for i, key in enumerate(edge_programs.keys()): all_processed_results[key].append( - PreprocessResult(processed_bytes=ctx_binary_list[i]) + PreprocessResult( + processed_bytes=ctx_binary_list[i], + debug_handle_map=debug_handle_builder.get_delegate_mapping(), + ) ) else: raise RuntimeError("Hybrid compilation is not supported") - return all_processed_results diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index bd4a5fb071e..928659ca098 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -178,8 +178,18 @@ def annotate_atan(node: Node, quantization_config: QuantizationConfig) -> None: def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): return - # We can use single_in_single_out since we don't want to quantize indices output - annotate_single_in_single_out(node, quantization_config) + + input_qspec_map = {} + if _is_float_tensor(node.args[0]): + input_act = node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = quantization_config.input_activation + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None: diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 17dc6bf4e19..9e91cbdb487 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -389,6 +390,11 @@ Error QnnManager::Execute( if (IsTensorDump()) { // TODO: Need to handle the graph which is partitioned. // Maybe we could use graph name. 
+ + // Parsing out the debug handle id + std::regex re("_debugID_(\\d+)"); + std::smatch match; + uint32_t debug_handle_id; for (std::size_t out_idx = 0; out_idx < output_tensor_structs.size(); ++out_idx) { const Qnn_Tensor_t& output_tensor = output_tensor_structs[out_idx]; @@ -403,13 +409,36 @@ Error QnnManager::Execute( qnn_dtype_to_scalar_type_[QNN_TENSOR_VER_PTR(output_tensor) ->dataType]); - executorch::runtime::event_tracer_log_output_delegate< - executorch::aten::Tensor>( - event_tracer, - QNN_TENSOR_VER_PTR(output_tensor)->name, - /*delegate_debug_id=*/ - static_cast(-1), - *dump_tensor); + std::string qnn_tensor_name = + std::string(QNN_TENSOR_VER_PTR(output_tensor)->name); + if (std::regex_search(qnn_tensor_name, match, re)) { + debug_handle_id = static_cast(std::stoul(match[1].str())); + + QNN_EXECUTORCH_LOG_INFO( + "Found the debug_handle id %d from qnn_tensor_name: %s", + debug_handle_id, + QNN_TENSOR_VER_PTR(output_tensor)->name); + executorch::runtime::event_tracer_log_output_delegate< + executorch::aten::Tensor>( + event_tracer, + /*name*/ + nullptr, + /*delegate_debug_id=*/ + static_cast(debug_handle_id), + *dump_tensor); + } else { + QNN_EXECUTORCH_LOG_INFO( + "Unable to find the debug_handle id from qnn_tensor_name: %s. Use qnn_tensor_name as key instead.", + QNN_TENSOR_VER_PTR(output_tensor)->name); + executorch::runtime::event_tracer_log_output_delegate< + executorch::aten::Tensor>( + event_tracer, + /*name*/ + QNN_TENSOR_VER_PTR(output_tensor)->name, + /*delegate_debug_id=*/ + static_cast(-1), + *dump_tensor); + } } } diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 2b73e0c6dfb..ca8e267338d 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -2190,6 +2190,7 @@ def __init__(self): self.idx_source = torch.rand(10, 3) def forward(self, x): + x = torch.nn.functional.relu(x) a, b = torch.topk(x, 3) return a + self.idx_source[b] diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index aa3f28b34ee..793764401a1 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -4571,6 +4571,7 @@ def setUp(self): ) def test_qnn_backend_dump_intermediate_outputs_topk(self): + torch.manual_seed(8) backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], @@ -4584,7 +4585,7 @@ def test_qnn_backend_dump_intermediate_outputs_topk(self): sample_input, expected_partitions=1, expected_intermediate_events=7, - expected_compared_events=5, + expected_compared_events=6, ) def test_qnn_backend_dump_intermediate_outputs_simple_model(self): @@ -5175,6 +5176,7 @@ def test_qnn_backend_dump_intermediate_outputs_simple_model(self): ) def test_qnn_backend_dump_intermediate_outputs_topk(self): + torch.manual_seed(8) backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], @@ -5188,8 +5190,8 @@ def test_qnn_backend_dump_intermediate_outputs_topk(self): module, sample_input, expected_partitions=1, - expected_intermediate_events=8, - expected_compared_events=5, + expected_intermediate_events=9, + expected_compared_events=6, ) def test_qnn_backend_dynamic_shape(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index f4b9339e1c2..edb9af64b21 100644 --- 
a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -310,8 +310,9 @@ def validate_intermediate_tensor():
                 inspector = Inspector(
                     etdump_path=etdump_path, debug_buffer_path=debug_output_path
                 )
+
                 node_tensor_map = qnn_intermediate_debugger._match_tensors(
-                    inspector=inspector, keep_qnn_layout=False
+                    inspector=inspector
                 )
                 self.assertEqual(
                     len(node_tensor_map),
@@ -321,7 +322,7 @@
                 # Compare accuracy for each layer
                 for _, value in node_tensor_map.items():
                     self._assert_outputs_equal(
-                        value[0].to(torch.float32), value[1].to(torch.float32)
+                        (value[0].to(torch.float32),), (value[1].to(torch.float32),)
                     )
                 for event_block in inspector.event_blocks:
                     if event_block.name == "Execute":
@@ -521,15 +522,18 @@ def lower_module_and_test_output(
             assert len(lowered_module_nodes) == 1, "Length not correct"
             lowered_module_node = lowered_module_nodes[0]
 
-            lower_module = getattr(
+            lowered_module = getattr(
                 delegated_program.exported_program().graph_module,
                 lowered_module_node.name,
            )
-            edge_module = lower_module.original_module.module()
+            edge_module = lowered_module.original_module.module()
 
             qnn_intermediate_debugger = QNNIntermediateDebugger()
-            qnn_intermediate_debugger.set_edge_module(edge_module=edge_module)
-            qnn_intermediate_debugger.intermediate_output_module(*sample_inputs)
+            qnn_intermediate_debugger._set_edge_module(
+                edge_module=edge_module,
+                debug_handle_map=lowered_module.meta["debug_handle_map"],
+            )
+            qnn_intermediate_debugger.capture_golden(sample_input=sample_inputs)
 
             exec_prog = delegated_program.to_executorch(
                 exir.ExecutorchBackendConfig(
diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py
index 5a6e7570e82..908c959fe91 100644
--- a/backends/qualcomm/utils/constants.py
+++ b/backends/qualcomm/utils/constants.py
@@ -16,6 +16,7 @@
 QCOM_BLOCK_SCALE_OFFSET = "block_scale_offset"
 QCOM_BLOCK_STORAGE_TYPE = "block_storage_type"
 QCOM_BYPASS_NODE = "bypass_node"
+QCOM_DEBUG_HANDLE = "debug_handle"
 QCOM_DATA = "data"
 QCOM_DTYPE = "dtype"
 QCOM_ENCODING = "encoding"
@@ -33,7 +34,6 @@
 QCOM_SCALE = "scale"
 QCOM_SCALES = "scales"
 QCOM_SCALE_OFFSET = "scale_offset"
-QCOM_TENSOR_NAME = "qnn_tensor_name"
 QCOM_ZERO_POINT = "zero_point"
 QCOM_ZERO_POINTS = "zero_points"
 QCOM_PASS_ACTIVATE_KEY = "activate"
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 4a68f434895..f2f846c5c7e 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -392,6 +392,21 @@ def ensure_graph_specific_dict(value, graph_names):
             return value
         return {graph_name: value for graph_name in graph_names}
 
+    # Ensure that, when the intermediate debugger is in use, only one method is lowered.
+    # This restriction is caused by conflicting handle_ids among graphs.
+    # It could be lifted by generating random debug ids (e.g., uuids).
+    for compiler_spec in (
+        compiler_specs.values()
+        if isinstance(compiler_specs, Dict)
+        else [compiler_specs]
+    ):
+        option = generate_qnn_executorch_option(compiler_spec)
+        obj_options = flatbuffer_to_option(option)
+        if obj_options.dump_intermediate_outputs and isinstance(module, Dict):
+            assert (
+                len(module) == 1
+            ), "Intermediate Tensor Dump does not support multi-methods."
+ if not isinstance(module, dict): module = {"forward": module} diff --git a/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py b/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py index 727c94900ca..92e0ca97be1 100644 --- a/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py +++ b/examples/qualcomm/util_scripts/qnn_intermediate_debugger_demo.py @@ -63,6 +63,7 @@ def main(args): pte_filename = "ic3_qnn_debug" instance = InceptionV3Model() source_model = instance.get_eager_model().eval() + # Init our QNNIntermediateDebugger and pass it in to build_executorch_binary(). qnn_intermediate_debugger = QNNIntermediateDebugger() build_executorch_binary( @@ -129,12 +130,11 @@ def validate_intermediate_tensor(): debug_buffer_path=f"{args.artifact}/debug_output.bin", ) - edge_result = qnn_intermediate_debugger.intermediate_output_module( - *(inputs[0]) - )[0] + qnn_intermediate_debugger.capture_golden(*(inputs[0])) # Optional: Ensures that edge module accuracy aligns with nn.Module with torch.no_grad(): + edge_result = qnn_intermediate_debugger.edge_module(*(inputs[0]))[0] source_result = source_model(*(inputs[0])) score = torch.nn.functional.cosine_similarity( edge_result.flatten(), source_result.flatten(), dim=0 diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index ca1d655c0db..513bb4cf400 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -493,11 +493,14 @@ def build_executorch_binary( ), "Graph with partitions are currently unsupported." lowered_module_node = lowered_module_nodes[0] - lower_module = getattr( + lowered_module = getattr( edge_prog_mgr.exported_program().graph_module, lowered_module_node.name ) - edge_module = lower_module.original_module.module() - qnn_intermediate_debugger.set_edge_module(edge_module=edge_module) + edge_module = lowered_module.original_module.module() + qnn_intermediate_debugger._set_edge_module( + edge_module=edge_module, + debug_handle_map=lowered_module.meta["debug_handle_map"], + ) executorch_config = ExecutorchBackendConfig( # For shared buffer, user must pass the memory address
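The pieces above are spread across several files, so here is a minimal end-to-end usage sketch of the reworked debugger flow, assembled only from the calls this patch introduces (QNNIntermediateDebugger, capture_golden, Inspector, generate_results). The artifact paths, the sample input shape, the choice of DUMP_RAW (which needs no evaluator), and the elided lowering/on-device steps are illustrative assumptions, not part of the change.

```python
# Sketch only: assumes the model was already lowered with dump_intermediate_outputs
# enabled (build_executorch_binary() calls debugger._set_edge_module(...) internally)
# and executed on device so that etdump.etdp / debug_output.bin exist.
import torch

from executorch.backends.qualcomm.debugger.qnn_intermediate_debugger import (
    OutputFormat,
    QNNIntermediateDebugger,
)
from executorch.devtools import Inspector

debugger = QNNIntermediateDebugger()

# ... lower the model via build_executorch_binary(), passing `debugger` as in the
#     demo script, then push/execute with SimpleADB; this populates
#     debugger.edge_module and produces the on-device dump files.

# Capture CPU golden intermediates for the same single set of inputs used on device.
sample_input = (torch.randn(1, 3, 224, 224),)  # illustrative input shape
debugger.capture_golden(sample_input)

# Parse the on-device dump; events are keyed by the debug handle encoded in the
# QNN tensor name (the "_debugID_<n>" suffix added in node_visitor.get_tensor_name).
inspector = Inspector(
    etdump_path="artifacts/etdump.etdp",
    debug_buffer_path="artifacts/debug_output.bin",
)

# DUMP_RAW writes <edge_node_name>_qnn.raw / _cpu.raw pairs and needs no evaluator.
debugger.generate_results(
    title="debug_graph",
    path="raw_outputs",
    output_format=OutputFormat.DUMP_RAW,
    inspector=inspector,
)
```

When a MetricEvaluatorBase implementation is available, OutputFormat.SVG_GRAPHS or OutputFormat.CSV_FILES can be used instead; those reports annotate each edge node with its debug_handle and the chosen metric.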