Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 187 additions & 1 deletion backends/arm/test/misc/test_vgf_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from types import SimpleNamespace
from typing import cast
from unittest import mock

import pytest

Expand All @@ -14,7 +16,14 @@
clear_registered_pass_insertions,
PassInsertions,
)
from executorch.backends.arm.vgf import backend as vgf_backend, VgfCompileSpec

from executorch.backends.arm.vgf import backend, backend as vgf_backend, VgfCompileSpec
from executorch.backends.arm.vgf.backend import (
_copy_failure_artifacts,
_format_repro_command,
_replace_converter_input_path,
vgf_compile,
)
from executorch.exir.backend.backend_details import PreprocessResult
from executorch.exir.pass_base import ExportPass
from torch.export.exported_program import ExportedProgram
Expand Down Expand Up @@ -105,3 +114,180 @@ def _raise(*args, **kwargs):
assert _registry_state() == original_registry
finally:
clear_registered_pass_insertions()


def test_format_repro_command_quotes_shell_metacharacters():
command = [
"model-converter",
"--flag=value with spaces",
"-i",
"input file.tosa",
"-o",
"output file.vgf",
]

formatted = _format_repro_command(command)

assert formatted == (
"model-converter "
"'--flag=value with spaces' "
"-i "
"'input file.tosa' "
"-o "
"'output file.vgf'"
)


def test_replace_converter_input_path_replaces_input_after_i():
command = [
"model-converter",
"--some-flag",
"-i",
"original.tosa",
"-o",
"output.vgf",
]

replaced = _replace_converter_input_path(command, "preserved.tosa")

assert replaced == [
"model-converter",
"--some-flag",
"-i",
"preserved.tosa",
"-o",
"output.vgf",
]
assert command[3] == "original.tosa"


def test_copy_failure_artifacts_returns_none_without_artifact_path(tmp_path):
tosa_path = tmp_path / "input.tosa"
tosa_path.write_bytes(b"tosa bytes")

copied_path = _copy_failure_artifacts(
str(tosa_path),
artifact_path=None,
tag_name="delegate_0",
)

assert copied_path is None


def test_copy_failure_artifacts_copies_tosa_with_tag_name(tmp_path):
tosa_path = tmp_path / "input.tosa"
artifact_path = tmp_path / "artifacts"
tosa_path.write_bytes(b"tosa bytes")

copied_path = _copy_failure_artifacts(
str(tosa_path),
str(artifact_path),
tag_name="delegate_0",
)

assert copied_path == os.path.join(
str(artifact_path),
"failed_model_converter_input_delegate_0.tosa",
)
assert os.path.exists(copied_path)
assert open(copied_path, "rb").read() == b"tosa bytes"


def test_copy_failure_artifacts_copies_tosa_without_tag_name(tmp_path):
tosa_path = tmp_path / "input.tosa"
artifact_path = tmp_path / "artifacts"
tosa_path.write_bytes(b"tosa bytes")

copied_path = _copy_failure_artifacts(
str(tosa_path),
str(artifact_path),
tag_name="",
)

assert copied_path == os.path.join(
str(artifact_path),
"failed_model_converter_input.tosa",
)
assert os.path.exists(copied_path)
assert open(copied_path, "rb").read() == b"tosa bytes"


@mock.patch("executorch.backends.arm.vgf.backend.model_converter_env")
@mock.patch("executorch.backends.arm.vgf.backend.require_model_converter_binary")
@mock.patch("executorch.backends.arm.vgf.backend.subprocess.run")
def test_vgf_compile_failure_includes_repro_command_and_copies_tosa(
mock_run,
mock_require_model_converter_binary,
mock_model_converter_env,
tmp_path,
):
artifact_path = tmp_path / "artifacts"

mock_require_model_converter_binary.return_value = "model-converter"
mock_model_converter_env.return_value = {"PATH": "/test/bin"}
mock_run.side_effect = backend.subprocess.CalledProcessError(
returncode=1,
cmd=["model-converter"],
output=b"converter stdout",
stderr=b"converter stderr",
)

with pytest.raises(RuntimeError) as exc_info:
vgf_compile(
b"serialized tosa",
["--flag=value with spaces"],
artifact_path=str(artifact_path),
tag_name="delegate_0",
)

copied_tosa_path = os.path.join(
str(artifact_path),
"failed_model_converter_input_delegate_0.tosa",
)

assert os.path.exists(copied_tosa_path)
assert open(copied_tosa_path, "rb").read() == b"serialized tosa"

error = str(exc_info.value)
assert "Vgf compiler failed." in error
assert "Repro command:" in error
assert "model-converter '--flag=value with spaces' -i" in error
assert copied_tosa_path in error
assert " -o " in error
assert "Stderr:\nconverter stderr" in error
assert "Stdout:\nconverter stdout" in error


@mock.patch("executorch.backends.arm.vgf.backend.model_converter_env")
@mock.patch("executorch.backends.arm.vgf.backend.require_model_converter_binary")
@mock.patch("executorch.backends.arm.vgf.backend.subprocess.run")
def test_vgf_compile_failure_includes_temp_repro_command_without_artifact_path(
mock_run,
mock_require_model_converter_binary,
mock_model_converter_env,
):
mock_require_model_converter_binary.return_value = "model-converter"
mock_model_converter_env.return_value = {"PATH": "/test/bin"}
mock_run.side_effect = backend.subprocess.CalledProcessError(
returncode=1,
cmd=["model-converter"],
output=b"converter stdout",
stderr=b"converter stderr",
)

with pytest.raises(RuntimeError) as exc_info:
vgf_compile(
b"serialized tosa",
["--some-flag"],
artifact_path=None,
tag_name="delegate_0",
)

error = str(exc_info.value)
assert "Vgf compiler failed." in error
assert "Repro command:" in error
assert "model-converter --some-flag -i" in error
assert "output_delegate_0.tosa.vgf" in error
assert "failed_model_converter_input_delegate_0.tosa" not in error
assert "Stderr:\nconverter stderr" in error
assert "Stdout:\nconverter stdout" in error
65 changes: 61 additions & 4 deletions backends/arm/vgf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging
import os # nosec B404 - used alongside subprocess for tool invocation
import shlex
import shutil
import subprocess # nosec B404 - required to drive external converter CLI
import tempfile
Expand Down Expand Up @@ -251,6 +252,52 @@ def preprocess(
return PreprocessResult(processed_bytes=binary)


def _format_repro_command(command: List[str]) -> str:
"""Return a shell-safe command string for reproducing converter failures."""
return " ".join(shlex.quote(arg) for arg in command)


def _copy_failure_artifacts(
tosa_path: str,
artifact_path: str | None,
tag_name: str,
) -> str | None:
"""Copy the failing TOSA input to the artifact directory, if configured.

Args:
tosa_path: Temporary TOSA flatbuffer passed to model-converter.
artifact_path: User-configured intermediate artifact directory.
tag_name: Optional delegation tag used to disambiguate artifacts.

Returns:
Path to the copied TOSA file, or None if no artifact path was configured.

"""
if not artifact_path:
return None

os.makedirs(artifact_path, exist_ok=True)

suffix = f"_{tag_name}" if tag_name else ""
failure_tosa_path = os.path.join(
artifact_path,
f"failed_model_converter_input{suffix}.tosa",
)
shutil.copy2(tosa_path, failure_tosa_path)
return failure_tosa_path


def _replace_converter_input_path(
conversion_command: List[str],
input_path: str,
) -> List[str]:
"""Return a converter command that uses a preserved TOSA input path."""
input_flag_index = conversion_command.index("-i")
repro_command = list(conversion_command)
repro_command[input_flag_index + 1] = input_path
return repro_command


def vgf_compile(
tosa_flatbuffer: bytes,
compile_flags: List[str],
Expand Down Expand Up @@ -299,11 +346,21 @@ def vgf_compile(
env=model_converter_env(),
)
except subprocess.CalledProcessError as process_error:
conversion_command_str = " ".join(conversion_command)
failure_tosa_path = _copy_failure_artifacts(
tosa_path,
artifact_path,
tag_name,
)
repro_command = (
_replace_converter_input_path(conversion_command, failure_tosa_path)
if failure_tosa_path
else conversion_command
)
raise RuntimeError(
f"Vgf compiler ('{conversion_command_str}') failed with error:\n \
{process_error.stderr.decode()}\n \
Stdout:\n{process_error.stdout.decode()}"
"Vgf compiler failed.\n"
f"Repro command:\n {_format_repro_command(repro_command)}\n"
f"Stderr:\n{process_error.stderr.decode()}\n"
f"Stdout:\n{process_error.stdout.decode()}"
)

if artifact_path:
Expand Down
Loading