diff --git a/src/google/adk/code_executors/unsafe_local_code_executor.py b/src/google/adk/code_executors/unsafe_local_code_executor.py index 64752fffd5..d3ffb5eccb 100644 --- a/src/google/adk/code_executors/unsafe_local_code_executor.py +++ b/src/google/adk/code_executors/unsafe_local_code_executor.py @@ -14,14 +14,11 @@ from __future__ import annotations -from contextlib import redirect_stdout -import io import logging -import multiprocessing -import queue -import re -import traceback -from typing import Any +import os +import subprocess +import sys +import tempfile from pydantic import Field from typing_extensions import override @@ -34,26 +31,6 @@ logger = logging.getLogger('google_adk.' + __name__) -def _execute_in_process( - code: str, globals_: dict[str, Any], result_queue: multiprocessing.Queue -) -> None: - """Executes code in a separate process and puts result in queue.""" - stdout = io.StringIO() - error = None - try: - with redirect_stdout(stdout): - exec(code, globals_, globals_) - except BaseException: - error = traceback.format_exc() - result_queue.put((stdout.getvalue(), error)) - - -def _prepare_globals(code: str, globals_: dict[str, Any]) -> None: - """Prepare globals for code execution, injecting __name__ if needed.""" - if re.search(r"if\s+__name__\s*==\s*['\"]__main__['\"]", code): - globals_['__name__'] = '__main__' - - class UnsafeLocalCodeExecutor(BaseCodeExecutor): """A code executor that unsafely execute code in the current local context.""" @@ -81,34 +58,31 @@ def execute_code( code_execution_input: CodeExecutionInput, ) -> CodeExecutionResult: logger.debug('Executing code:\n```\n%s\n```', code_execution_input.code) - # Execute the code. - globals_ = {} - _prepare_globals(code_execution_input.code, globals_) - - ctx = multiprocessing.get_context('spawn') - result_queue = ctx.Queue() - process = ctx.Process( - target=_execute_in_process, - args=(code_execution_input.code, globals_, result_queue), - daemon=True, - ) - process.start() - output = '' - error = '' - try: - output, err = result_queue.get(timeout=self.timeout_seconds) - process.join() - if err: - error = err - except queue.Empty: - process.terminate() - process.join() - error = f'Code execution timed out after {self.timeout_seconds} seconds.' + with tempfile.TemporaryDirectory() as temp_dir: + code_path = os.path.join(temp_dir, 'main.py') + with open(code_path, 'w', encoding='utf-8') as f: + f.write(code_execution_input.code) + + output = '' + error = '' + try: + result = subprocess.run( + [sys.executable, code_path], + capture_output=True, + text=True, + timeout=self.timeout_seconds, + cwd=temp_dir, + ) + output = result.stdout + if result.returncode != 0: + error = result.stderr + except subprocess.TimeoutExpired as e: + output = e.stdout if e.stdout else '' + error = f'Code execution timed out after {self.timeout_seconds} seconds.' + except Exception as e: + error = str(e) - # Collect the final result. - result_queue.close() - result_queue.join_thread() return CodeExecutionResult( stdout=output, stderr=error, diff --git a/src/google/adk/evaluation/local_eval_set_results_manager.py b/src/google/adk/evaluation/local_eval_set_results_manager.py index c6da638abe..656d9f411e 100644 --- a/src/google/adk/evaluation/local_eval_set_results_manager.py +++ b/src/google/adk/evaluation/local_eval_set_results_manager.py @@ -16,6 +16,7 @@ import logging import os +import re from typing_extensions import override @@ -67,6 +68,7 @@ def get_eval_set_result( self, app_name: str, eval_set_result_id: str ) -> EvalSetResult: """Returns an EvalSetResult identified by app_name and eval_set_result_id.""" + self._validate_id("Eval Set Result ID", eval_set_result_id) # Load the eval set result file data. maybe_eval_result_file_path = ( os.path.join( @@ -97,4 +99,12 @@ def list_eval_set_results(self, app_name: str) -> list[str]: return eval_result_files def _get_eval_history_dir(self, app_name: str) -> str: + self._validate_id("App Name", app_name) return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR) + + def _validate_id(self, id_name: str, id_value: str): + pattern = r"^[a-zA-Z0-9_\-\.]+$" + if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value: + raise ValueError( + f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`", + ) diff --git a/src/google/adk/evaluation/local_eval_sets_manager.py b/src/google/adk/evaluation/local_eval_sets_manager.py index 8d2290b911..3f2f0ca77f 100644 --- a/src/google/adk/evaluation/local_eval_sets_manager.py +++ b/src/google/adk/evaluation/local_eval_sets_manager.py @@ -201,7 +201,7 @@ def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: try: eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) return load_eval_set_from_file(eval_set_file_path, eval_set_id) - except FileNotFoundError: + except (FileNotFoundError, ValueError): return None @override @@ -211,8 +211,6 @@ def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet: Raises: ValueError: If Eval Set ID is not valid or an eval set already exists. """ - self._validate_id(id_name="Eval Set ID", id_value=eval_set_id) - # Define the file path new_eval_set_path = self._get_eval_set_file_path(app_name, eval_set_id) @@ -247,6 +245,7 @@ def list_eval_sets(self, app_name: str) -> list[str]: Raises: NotFoundError: If the eval directory for the app is not found. """ + self._validate_id("App Name", app_name) eval_set_file_path = os.path.join(self._agents_dir, app_name) eval_sets = [] try: @@ -266,6 +265,7 @@ def get_eval_case( self, app_name: str, eval_set_id: str, eval_case_id: str ) -> Optional[EvalCase]: """Returns an EvalCase if found; otherwise, None.""" + self._validate_id("Eval Case ID", eval_case_id) eval_set = self.get_eval_set(app_name, eval_set_id) if not eval_set: return None @@ -310,6 +310,8 @@ def delete_eval_case( self._save_eval_set(app_name, eval_set_id, updated_eval_set) def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: + self._validate_id("App Name", app_name) + self._validate_id("Eval Set ID", eval_set_id) return os.path.join( self._agents_dir, app_name, @@ -317,10 +319,10 @@ def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: ) def _validate_id(self, id_name: str, id_value: str): - pattern = r"^[a-zA-Z0-9_]+$" - if not bool(re.fullmatch(pattern, id_value)): + pattern = r"^[a-zA-Z0-9_\-\.]+$" + if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value: raise ValueError( - f"Invalid {id_name}. {id_name} should have the `{pattern}` format", + f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`", ) def _write_eval_set_to_path(self, eval_set_path: str, eval_set: EvalSet): diff --git a/tests/unittests/evaluation/test_local_eval_set_results_manager.py b/tests/unittests/evaluation/test_local_eval_set_results_manager.py index 4647392628..5b2c873e29 100644 --- a/tests/unittests/evaluation/test_local_eval_set_results_manager.py +++ b/tests/unittests/evaluation/test_local_eval_set_results_manager.py @@ -174,3 +174,11 @@ def test_list_eval_set_results_empty(self): # No eval set results saved for the app results = self.manager.list_eval_set_results(self.app_name) assert results == [] + + def test_get_eval_history_dir_invalid_app_name(self): + with pytest.raises(ValueError, match="Invalid App Name"): + self.manager.list_eval_set_results("../invalid") + + def test_get_eval_set_result_invalid_id(self): + with pytest.raises(ValueError, match="Invalid Eval Set Result ID"): + self.manager.get_eval_set_result(self.app_name, "../invalid_id") diff --git a/tests/unittests/evaluation/test_local_eval_sets_manager.py b/tests/unittests/evaluation/test_local_eval_sets_manager.py index 3450fb9338..67e089a3db 100644 --- a/tests/unittests/evaluation/test_local_eval_sets_manager.py +++ b/tests/unittests/evaluation/test_local_eval_sets_manager.py @@ -390,11 +390,20 @@ def test_local_eval_sets_manager_create_eval_set_invalid_id( self, local_eval_sets_manager ): app_name = "test_app" - eval_set_id = "invalid-id" + eval_set_id = "invalid/id" with pytest.raises(ValueError, match="Invalid Eval Set ID"): local_eval_sets_manager.create_eval_set(app_name, eval_set_id) + def test_local_eval_sets_manager_create_eval_set_invalid_app_name( + self, local_eval_sets_manager + ): + app_name = "../test_app" + eval_set_id = "test_eval_set" + + with pytest.raises(ValueError, match="Invalid App Name"): + local_eval_sets_manager.create_eval_set(app_name, eval_set_id) + def test_local_eval_sets_manager_create_eval_set_already_exists( self, local_eval_sets_manager, mocker ):