diff --git a/.github/workflows/js-tests.yml b/.github/workflows/js-tests.yml index 0d56e8831..1cef1f8af 100644 --- a/.github/workflows/js-tests.yml +++ b/.github/workflows/js-tests.yml @@ -47,4 +47,5 @@ jobs: uv run pytest tests/test_languages/test_vitest_e2e.py -v uv run pytest tests/test_languages/test_javascript_e2e.py -v uv run pytest tests/test_languages/test_javascript_support.py -v + uv run pytest tests/test_languages/test_javascript_tracer.py -v uv run pytest tests/code_utils/test_config_js.py -v diff --git a/codeflash/languages/javascript/parse.py b/codeflash/languages/javascript/parse.py deleted file mode 100644 index 0d62b50b4..000000000 --- a/codeflash/languages/javascript/parse.py +++ /dev/null @@ -1,458 +0,0 @@ -"""Jest/Vitest JUnit XML parsing for JavaScript/TypeScript tests. - -This module handles parsing of JUnit XML test results produced by Jest and Vitest -test runners. It extracts test results, timing information, and maps them back -to instrumented test files. -""" - -from __future__ import annotations - -import contextlib -import json -import re -from pathlib import Path -from typing import TYPE_CHECKING - -from junitparser.xunit2 import JUnitXml - -from codeflash.cli_cmds.console import logger -from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType - -if TYPE_CHECKING: - import subprocess - - from codeflash.models.models import TestFiles - from codeflash.verification.verification_utils import TestConfig - - -# Jest timing marker patterns (from codeflash-jest-helper.js console.log output) -# Format: !$######testName:testName:funcName:loopIndex:lineId######$! (start) -# Format: !######testName:testName:funcName:loopIndex:lineId:durationNs######! (end) -jest_start_pattern = re.compile(r"!\$######([^:]+):([^:]+):([^:]+):([^:]+):([^#]+)######\$!") -jest_end_pattern = re.compile(r"!######([^:]+):([^:]+):([^:]+):([^:]+):([^:]+):(\d+)######!") - - -def _extract_jest_console_output(suite_elem) -> str: - """Extract console output from Jest's JUnit XML system-out element. - - Jest-junit writes console.log output as a JSON array in the testsuite's system-out. - Each entry has: {"message": "...", "origin": "...", "type": "log"} - - Args: - suite_elem: The testsuite lxml element - - Returns: - Concatenated message content from all log entries - - """ - system_out_elem = suite_elem.find("system-out") - if system_out_elem is None or system_out_elem.text is None: - return "" - - raw_content = system_out_elem.text.strip() - if not raw_content: - return "" - - # Jest-junit wraps console output in a JSON array - # Try to parse as JSON first - try: - log_entries = json.loads(raw_content) - if isinstance(log_entries, list): - # Extract message field from each log entry - messages = [] - for entry in log_entries: - if isinstance(entry, dict) and "message" in entry: - messages.append(entry["message"]) - return "\n".join(messages) - except (json.JSONDecodeError, TypeError): - # Not JSON - return as plain text (fallback for pytest-style output) - pass - - return raw_content - - -def parse_jest_test_xml( - test_xml_file_path: Path, - test_files: TestFiles, - test_config: TestConfig, - run_result: subprocess.CompletedProcess | None = None, - parse_func=None, - resolve_test_file_from_class_path=None, -) -> TestResults: - """Parse Jest JUnit XML test results. 
- - Jest-junit has a different structure than pytest: - - system-out is at the testsuite level (not testcase) - - system-out contains a JSON array of log entries - - Timing markers are in the message field of log entries - - Args: - test_xml_file_path: Path to the Jest JUnit XML file - test_files: TestFiles object with test file information - test_config: Test configuration - run_result: Optional subprocess result for logging - parse_func: XML parser function (injected to avoid circular imports) - resolve_test_file_from_class_path: Function to resolve test file paths (injected) - - Returns: - TestResults containing parsed test invocations - - """ - test_results = TestResults() - - if not test_xml_file_path.exists(): - logger.warning(f"No JavaScript test results for {test_xml_file_path} found.") - return test_results - - # Log file size for debugging - file_size = test_xml_file_path.stat().st_size - logger.debug(f"Jest XML file size: {file_size} bytes at {test_xml_file_path}") - - try: - xml = JUnitXml.fromfile(str(test_xml_file_path), parse_func=parse_func) - logger.debug(f"Successfully parsed Jest JUnit XML from {test_xml_file_path}") - except Exception as e: - logger.warning(f"Failed to parse {test_xml_file_path} as JUnitXml. Exception: {e}") - return test_results - - base_dir = test_config.tests_project_rootdir - logger.debug(f"Jest XML parsing: base_dir={base_dir}, num_test_files={len(test_files.test_files)}") - - # Build lookup from instrumented file path to TestFile for direct matching - # This handles cases where instrumented files are in temp directories - instrumented_path_lookup: dict[str, tuple[Path, TestType]] = {} - for test_file in test_files.test_files: - # Add behavior instrumented file paths - if test_file.instrumented_behavior_file_path: - # Store both the absolute path and resolved path as keys - abs_path = str(test_file.instrumented_behavior_file_path.resolve()) - instrumented_path_lookup[abs_path] = (test_file.instrumented_behavior_file_path, test_file.test_type) - # Also store the string representation in case of minor path differences - instrumented_path_lookup[str(test_file.instrumented_behavior_file_path)] = ( - test_file.instrumented_behavior_file_path, - test_file.test_type, - ) - logger.debug(f"Jest XML lookup: registered {abs_path}") - # Also add benchmarking file paths (perf-only instrumented tests) - if test_file.benchmarking_file_path: - bench_abs_path = str(test_file.benchmarking_file_path.resolve()) - instrumented_path_lookup[bench_abs_path] = (test_file.benchmarking_file_path, test_file.test_type) - instrumented_path_lookup[str(test_file.benchmarking_file_path)] = ( - test_file.benchmarking_file_path, - test_file.test_type, - ) - logger.debug(f"Jest XML lookup: registered benchmark {bench_abs_path}") - - # Also build a filename-only lookup for fallback matching - # This handles cases where JUnit XML has relative paths that don't match absolute paths - # e.g., JUnit has "test/utils__perfinstrumented.test.ts" but lookup has absolute paths - filename_lookup: dict[str, tuple[Path, TestType]] = {} - for test_file in test_files.test_files: - # Add instrumented_behavior_file_path (behavior tests) - if test_file.instrumented_behavior_file_path: - filename = test_file.instrumented_behavior_file_path.name - # Only add if not already present (avoid overwrites in case of duplicate filenames) - if filename not in filename_lookup: - filename_lookup[filename] = (test_file.instrumented_behavior_file_path, test_file.test_type) - logger.debug(f"Jest XML filename lookup: 
registered {filename}") - # Also add benchmarking_file_path (perf-only tests) - these have different filenames - # e.g., utils__perfonlyinstrumented.test.ts vs utils__perfinstrumented.test.ts - if test_file.benchmarking_file_path: - bench_filename = test_file.benchmarking_file_path.name - if bench_filename not in filename_lookup: - filename_lookup[bench_filename] = (test_file.benchmarking_file_path, test_file.test_type) - logger.debug(f"Jest XML filename lookup: registered benchmark file {bench_filename}") - - # Fallback: if JUnit XML doesn't have system-out, use subprocess stdout directly - global_stdout = "" - if run_result is not None: - try: - global_stdout = run_result.stdout if isinstance(run_result.stdout, str) else run_result.stdout.decode() - # Debug: log if timing markers are found in stdout - if global_stdout: - marker_count = len(jest_start_pattern.findall(global_stdout)) - if marker_count > 0: - logger.debug(f"Found {marker_count} timing start markers in Jest stdout") - else: - logger.debug(f"No timing start markers found in Jest stdout (len={len(global_stdout)})") - except (AttributeError, UnicodeDecodeError): - global_stdout = "" - - suite_count = 0 - testcase_count = 0 - for suite in xml: - suite_count += 1 - # Extract console output from suite-level system-out (Jest specific) - suite_stdout = _extract_jest_console_output(suite._elem) # noqa: SLF001 - - # Fallback: use subprocess stdout if XML system-out is empty - if not suite_stdout and global_stdout: - suite_stdout = global_stdout - - # Parse timing markers from the suite's console output - start_matches = list(jest_start_pattern.finditer(suite_stdout)) - end_matches_dict = {} - for match in jest_end_pattern.finditer(suite_stdout): - # Key: (testName, testName2, funcName, loopIndex, lineId) - key = match.groups()[:5] - end_matches_dict[key] = match - - # Also collect timing markers from testcase-level system-out (Vitest puts output at testcase level) - for tc in suite: - tc_system_out = tc._elem.find("system-out") # noqa: SLF001 - if tc_system_out is not None and tc_system_out.text: - tc_stdout = tc_system_out.text.strip() - logger.debug(f"Vitest testcase system-out found: {len(tc_stdout)} chars, first 200: {tc_stdout[:200]}") - end_marker_count = 0 - for match in jest_end_pattern.finditer(tc_stdout): - key = match.groups()[:5] - end_matches_dict[key] = match - end_marker_count += 1 - if end_marker_count > 0: - logger.debug(f"Found {end_marker_count} END timing markers in testcase system-out") - start_matches.extend(jest_start_pattern.finditer(tc_stdout)) - - for testcase in suite: - testcase_count += 1 - test_class_path = testcase.classname # For Jest, this is the file path - test_name = testcase.name - - if test_name is None: - logger.debug(f"testcase.name is None in Jest XML {test_xml_file_path}, skipping") - continue - - logger.debug(f"Jest XML: processing testcase name={test_name}, classname={test_class_path}") - - # First, try direct lookup in instrumented file paths - # This handles cases where instrumented files are in temp directories - test_file_path = None - test_type = None - - if test_class_path: - # Try exact match with classname (which should be the filepath from jest-junit) - if test_class_path in instrumented_path_lookup: - test_file_path, test_type = instrumented_path_lookup[test_class_path] - else: - # Try resolving the path and matching - try: - resolved_path = str(Path(test_class_path).resolve()) - if resolved_path in instrumented_path_lookup: - test_file_path, test_type = 
instrumented_path_lookup[resolved_path] - except Exception: - pass - - # If direct lookup failed, try the file attribute - if test_file_path is None: - test_file_name = suite._elem.attrib.get("file") or testcase._elem.attrib.get("file") # noqa: SLF001 - if test_file_name: - if test_file_name in instrumented_path_lookup: - test_file_path, test_type = instrumented_path_lookup[test_file_name] - else: - try: - resolved_path = str(Path(test_file_name).resolve()) - if resolved_path in instrumented_path_lookup: - test_file_path, test_type = instrumented_path_lookup[resolved_path] - except Exception: - pass - - # Fall back to traditional path resolution if direct lookup failed - if test_file_path is None and resolve_test_file_from_class_path is not None: - test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir) - if test_file_path is None: - test_file_name = suite._elem.attrib.get("file") or testcase._elem.attrib.get("file") # noqa: SLF001 - if test_file_name: - test_file_path = base_dir.parent / test_file_name - if not test_file_path.exists(): - test_file_path = base_dir / test_file_name - - # Fallback: try matching by filename only - # This handles when JUnit XML has relative paths like "test/utils__perfinstrumented.test.ts" - # that can't be resolved to absolute paths because they're relative to Jest's CWD, not parse CWD - if test_file_path is None and test_class_path: - # Extract filename from the path (handles both forward and back slashes) - path_filename = Path(test_class_path).name - if path_filename in filename_lookup: - test_file_path, test_type = filename_lookup[path_filename] - logger.debug(f"Jest XML: matched by filename {path_filename}") - - # Also try filename matching on the file attribute if classname matching failed - if test_file_path is None: - test_file_name = suite._elem.attrib.get("file") or testcase._elem.attrib.get("file") # noqa: SLF001 - if test_file_name: - file_attr_filename = Path(test_file_name).name - if file_attr_filename in filename_lookup: - test_file_path, test_type = filename_lookup[file_attr_filename] - logger.debug(f"Jest XML: matched by file attr filename {file_attr_filename}") - - # For Jest tests in monorepos, test files may not exist after cleanup - # but we can still parse results and infer test type from the path - if test_file_path is None: - logger.warning(f"Could not resolve test file for Jest test: {test_class_path}") - continue - - # Get test type if not already set from lookup - if test_type is None and test_file_path.exists(): - test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) - if test_type is None: - # Infer test type from filename pattern - filename = test_file_path.name - if "__perf_test_" in filename or "_perf_test_" in filename: - test_type = TestType.GENERATED_PERFORMANCE - elif "__unit_test_" in filename or "_unit_test_" in filename: - test_type = TestType.GENERATED_REGRESSION - else: - # Default to GENERATED_REGRESSION for Jest tests - test_type = TestType.GENERATED_REGRESSION - - # For Jest tests, keep the relative file path with extension intact - # (Python uses module_name_from_file_path which strips extensions) - try: - test_module_path = str(test_file_path.relative_to(test_config.tests_project_rootdir)) - except ValueError: - test_module_path = test_file_path.name - result = testcase.is_passed - - # Check for timeout - timed_out = False - if len(testcase.result) >= 1: - message = (testcase.result[0].message or "").lower() - if "timeout" in message or "timed out" in message: - 
timed_out = True - - # Find matching timing markers for this test - # Jest test names in markers are sanitized by codeflash-jest-helper's sanitizeTestId() - # which replaces: !#: (space) ()[]{}|\/*?^$.+- with underscores - # IMPORTANT: Must match Jest helper's sanitization exactly for marker matching to work - # Pattern from capture.js: /[!#: ()\[\]{}|\\/*?^$.+\-]/g - sanitized_test_name = re.sub(r"[!#: ()\[\]{}|\\/*?^$.+\-]", "_", test_name) - matching_starts = [m for m in start_matches if sanitized_test_name in m.group(2)] - - # For performance tests (capturePerf), there are no START markers - only END markers with duration - # Check for END markers directly if no START markers found - matching_ends_direct = [] - if not matching_starts: - # Look for END markers that match this test (performance test format) - # END marker format: !######module:testName:funcName:loopIndex:invocationId:durationNs######! - for end_key, end_match in end_matches_dict.items(): - # end_key is (module, testName, funcName, loopIndex, invocationId) - if len(end_key) >= 2 and sanitized_test_name in end_key[1]: - matching_ends_direct.append(end_match) - - if not matching_starts and not matching_ends_direct: - # No timing markers found - use JUnit XML time attribute as fallback - # The time attribute is in seconds (e.g., "0.00077875"), convert to nanoseconds - runtime = None - try: - time_attr = testcase._elem.attrib.get("time") # noqa: SLF001 - if time_attr: - time_seconds = float(time_attr) - runtime = int(time_seconds * 1_000_000_000) # Convert seconds to nanoseconds - logger.debug(f"Jest XML: using time attribute for {test_name}: {time_seconds}s = {runtime}ns") - except (ValueError, TypeError) as e: - logger.debug(f"Jest XML: could not parse time attribute: {e}") - - test_results.add( - FunctionTestInvocation( - loop_index=1, - id=InvocationId( - test_module_path=test_module_path, - test_class_name=None, - test_function_name=test_name, - function_getting_tested="", - iteration_id="", - ), - file_name=test_file_path, - runtime=runtime, - test_framework=test_config.test_framework, - did_pass=result, - test_type=test_type, - return_value=None, - timed_out=timed_out, - stdout="", - ) - ) - elif matching_ends_direct: - # Performance test format: process END markers directly (no START markers) - for end_match in matching_ends_direct: - groups = end_match.groups() - # groups: (module, testName, funcName, loopIndex, invocationId, durationNs) - func_name = groups[2] - loop_index = int(groups[3]) if groups[3].isdigit() else 1 - line_id = groups[4] - try: - runtime = int(groups[5]) - except (ValueError, IndexError): - runtime = None - test_results.add( - FunctionTestInvocation( - loop_index=loop_index, - id=InvocationId( - test_module_path=test_module_path, - test_class_name=None, - test_function_name=test_name, - function_getting_tested=func_name, - iteration_id=line_id, - ), - file_name=test_file_path, - runtime=runtime, - test_framework=test_config.test_framework, - did_pass=result, - test_type=test_type, - return_value=None, - timed_out=timed_out, - stdout="", - ) - ) - else: - # Process each timing marker - for match in matching_starts: - groups = match.groups() - # groups: (testName, testName2, funcName, loopIndex, lineId) - func_name = groups[2] - loop_index = int(groups[3]) if groups[3].isdigit() else 1 - line_id = groups[4] - - # Find matching end marker - end_key = groups[:5] - end_match = end_matches_dict.get(end_key) - - runtime = None - if end_match: - # Duration is in the 6th group (index 5) - with 
contextlib.suppress(ValueError, IndexError): - runtime = int(end_match.group(6)) - test_results.add( - FunctionTestInvocation( - loop_index=loop_index, - id=InvocationId( - test_module_path=test_module_path, - test_class_name=None, - test_function_name=test_name, - function_getting_tested=func_name, - iteration_id=line_id, - ), - file_name=test_file_path, - runtime=runtime, - test_framework=test_config.test_framework, - did_pass=result, - test_type=test_type, - return_value=None, - timed_out=timed_out, - stdout="", - ) - ) - - if not test_results: - logger.info( - f"No Jest test results parsed from {test_xml_file_path} " - f"(found {suite_count} suites, {testcase_count} testcases)" - ) - if run_result is not None: - logger.debug(f"Jest stdout: {run_result.stdout[:1000] if run_result.stdout else 'empty'}") - else: - logger.debug( - f"Jest XML parsing complete: {len(test_results.test_results)} results " - f"from {suite_count} suites, {testcase_count} testcases" - ) - - return test_results diff --git a/codeflash/languages/javascript/replay_test.py b/codeflash/languages/javascript/replay_test.py new file mode 100644 index 000000000..7dc80a4f0 --- /dev/null +++ b/codeflash/languages/javascript/replay_test.py @@ -0,0 +1,340 @@ +"""JavaScript replay test generation. + +This module provides functionality to generate replay tests from traced JavaScript +function calls. Replay tests allow verifying that optimized code produces the same +results as the original code. + +The generated tests can be run with Jest or Vitest, depending on the project's +test framework configuration. +""" + +from __future__ import annotations + +import json +import sqlite3 +import textwrap +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from collections.abc import Generator + + +@dataclass +class JavaScriptFunctionModule: + """Information about a traced JavaScript function for replay test generation.""" + + function_name: str + file_name: Path + module_name: str + class_name: Optional[str] = None + line_no: Optional[int] = None + + +def get_next_arg_and_return( + trace_file: str, function_name: str, file_name: str, class_name: Optional[str] = None, num_to_get: int = 25 +) -> Generator[Any]: + """Get traced function arguments from the database. + + This mirrors the Python version in codeflash/tracing/replay_test.py. + + Args: + trace_file: Path to the trace SQLite database. + function_name: Name of the function. + file_name: Path to the source file. + class_name: Optional class name for methods. + num_to_get: Maximum number of traces to retrieve. + + Yields: + Serialized argument data for each traced call. + + """ + db = sqlite3.connect(trace_file) + cur = db.cursor() + + # Try the new schema first (function_calls table) + try: + cur.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cur.fetchall()} + + if "function_calls" in tables: + if class_name: + cursor = cur.execute( + "SELECT args FROM function_calls WHERE function = ? AND filename = ? AND classname = ? AND type = 'call' ORDER BY time_ns ASC LIMIT ?", + (function_name, file_name, class_name, num_to_get), + ) + else: + cursor = cur.execute( + "SELECT args FROM function_calls WHERE function = ? AND filename = ? 
AND type = 'call' ORDER BY time_ns ASC LIMIT ?",
+                    (function_name, file_name, num_to_get),
+                )
+
+            while (val := cursor.fetchone()) is not None:
+                # args is stored as JSON text or a binary blob; yield it
+                # as-is and let the caller deserialize it.
+                yield val[0]
+
+        elif "traces" in tables:
+            # Legacy schema: the traces table has no classname column, so
+            # class_name cannot be used to narrow the query here.
+            cursor = cur.execute(
+                "SELECT args FROM traces WHERE function = ? AND file = ? ORDER BY id ASC LIMIT ?",
+                (function_name, file_name, num_to_get),
+            )
+
+            while (val := cursor.fetchone()) is not None:
+                yield val[0]
+
+    finally:
+        db.close()
+
+
+def get_function_alias(module: str, function_name: str, class_name: Optional[str] = None) -> str:
+    """Generate a unique alias for a function import.
+
+    Args:
+        module: Module path.
+        function_name: Function name.
+        class_name: Optional class name.
+
+    Returns:
+        A valid JavaScript identifier for the function.
+
+    """
+    import re
+
+    # Normalize module path to valid identifier
+    module_alias = re.sub(r"[^a-zA-Z0-9]", "_", module).strip("_")
+
+    if class_name:
+        return f"{module_alias}_{class_name}_{function_name}"
+    return f"{module_alias}_{function_name}"
+
+
+def create_javascript_replay_test(
+    trace_file: str,
+    functions: list[JavaScriptFunctionModule],
+    max_run_count: int = 100,
+    framework: str = "jest",
+    project_root: Optional[Path] = None,
+) -> str:
+    """Generate a JavaScript replay test file from traced function calls.
+
+    This mirrors the Python version in codeflash/tracing/replay_test.py but
+    generates JavaScript test code for Jest or Vitest.
+
+    Args:
+        trace_file: Path to the trace SQLite database.
+        functions: List of functions to generate tests for.
+        max_run_count: Maximum number of test cases per function.
+        framework: Test framework ('jest' or 'vitest').
+        project_root: Project root for calculating relative imports.
+
+    Returns:
+        Generated test file content as a string.
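+
+    Example (illustrative; the traced module and function names below are
+    hypothetical)::
+
+        functions = [JavaScriptFunctionModule(
+            function_name="add", file_name=Path("src/math.js"), module_name="src/math"
+        )]
+        content = create_javascript_replay_test(
+            "codeflash.trace.sqlite", functions, framework="jest"
+        )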
+ + """ + is_vitest = framework.lower() == "vitest" + + # Build imports section + imports = [] + + if is_vitest: + imports.append("import { describe, test } from 'vitest';") + + imports.append("const { getNextArg } = require('codeflash/replay');") + imports.append("") + + # Build function imports + for func in functions: + if func.function_name in ("__init__", "constructor"): + # Skip constructors + continue + + alias = get_function_alias(func.module_name, func.function_name, func.class_name) + + if func.class_name: + imports.append(f"const {{ {func.class_name}: {alias}_class }} = require('./{func.module_name}');") + else: + imports.append(f"const {{ {func.function_name}: {alias} }} = require('./{func.module_name}');") + + imports.append("") + + # Metadata + functions_to_test = [f.function_name for f in functions if f.function_name not in ("__init__", "constructor")] + metadata = f"""const traceFilePath = '{trace_file}'; +const functions = {json.dumps(functions_to_test)}; +""" + + # Build test cases + test_cases = [] + + for func in functions: + if func.function_name in ("__init__", "constructor"): + continue + + alias = get_function_alias(func.module_name, func.function_name, func.class_name) + test_name = f"{func.class_name}.{func.function_name}" if func.class_name else func.function_name + + if func.class_name: + # Method test - need to instantiate the class + class_arg = f"'{func.class_name}'" + test_body = textwrap.dedent(f""" +describe('Replay: {test_name}', () => {{ + const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name.as_posix()}', {max_run_count}, {class_arg}); + + test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{ + // For instance methods, we need to create an instance + // The traced args may include 'this' context as first argument + const instance = new {alias}_class(); + instance.{func.function_name}(...args); + }}); +}}); +""") + else: + # Regular function test + test_body = textwrap.dedent(f""" +describe('Replay: {test_name}', () => {{ + const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name.as_posix()}', {max_run_count}); + + test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{ + {alias}(...args); + }}); +}}); +""") + + test_cases.append(test_body) + + # Combine all parts + return "\n".join( + [ + "// Auto-generated replay test by Codeflash", + "// Do not edit this file directly", + "", + *imports, + metadata, + *test_cases, + ] + ) + + +def get_traced_functions_from_db(trace_file: Path) -> list[JavaScriptFunctionModule]: + """Get list of functions that were traced from the database. + + Args: + trace_file: Path to trace database. + + Returns: + List of traced function information. 
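+
+    Example (illustrative; the trace path is hypothetical)::
+
+        for fn in get_traced_functions_from_db(Path("codeflash.trace.sqlite")):
+            print(fn.module_name, fn.function_name)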
+
+    """
+    if not trace_file.exists():
+        return []
+
+    try:
+        conn = sqlite3.connect(trace_file)
+        cursor = conn.cursor()
+
+        # Check schema
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+        tables = {row[0] for row in cursor.fetchall()}
+
+        functions = []
+
+        if "function_calls" in tables:
+            cursor.execute(
+                "SELECT DISTINCT function, filename, classname, line_number FROM function_calls WHERE type = 'call'"
+            )
+            for row in cursor.fetchall():
+                func_name = row[0]
+                file_name = row[1]
+                class_name = row[2]
+                line_number = row[3]
+
+                # Calculate module path from filename, stripping only a
+                # trailing extension so ".js"/".ts" elsewhere in the path
+                # (and .jsx/.tsx suffixes) are not mangled
+                module_path = file_name.replace("\\", "/")
+                for ext in (".jsx", ".tsx", ".js", ".ts"):
+                    if module_path.endswith(ext):
+                        module_path = module_path[: -len(ext)]
+                        break
+                module_path = module_path.removeprefix("./")
+
+                functions.append(
+                    JavaScriptFunctionModule(
+                        function_name=func_name,
+                        file_name=Path(file_name),
+                        module_name=module_path,
+                        class_name=class_name,
+                        line_no=line_number,
+                    )
+                )
+
+        elif "traces" in tables:
+            # Legacy schema
+            cursor.execute("SELECT DISTINCT function, file FROM traces")
+            for row in cursor.fetchall():
+                func_name = row[0]
+                file_name = row[1]
+
+                module_path = file_name.replace("\\", "/")
+                for ext in (".jsx", ".tsx", ".js", ".ts"):
+                    if module_path.endswith(ext):
+                        module_path = module_path[: -len(ext)]
+                        break
+                module_path = module_path.removeprefix("./")
+
+                functions.append(
+                    JavaScriptFunctionModule(
+                        function_name=func_name, file_name=Path(file_name), module_name=module_path
+                    )
+                )
+
+        conn.close()
+        return functions
+
+    except Exception:
+        return []
+
+
+def create_replay_test_file(
+    trace_file: Path,
+    output_path: Path,
+    framework: str = "jest",
+    max_run_count: int = 100,
+    project_root: Optional[Path] = None,
+) -> Optional[Path]:
+    """Generate a replay test file from a trace database.
+
+    This is the main entry point for creating JavaScript replay tests.
+
+    Args:
+        trace_file: Path to the trace SQLite database.
+        output_path: Path to write the test file.
+        framework: Test framework ('jest' or 'vitest').
+        max_run_count: Maximum number of test cases per function.
+        project_root: Project root for calculating relative imports.
+
+    Returns:
+        Path to generated test file, or None if generation failed.
+
+    """
+    functions = get_traced_functions_from_db(trace_file)
+
+    if not functions:
+        return None
+
+    content = create_javascript_replay_test(
+        trace_file=str(trace_file),
+        functions=functions,
+        max_run_count=max_run_count,
+        framework=framework,
+        project_root=project_root,
+    )
+
+    try:
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        output_path.write_text(content)
+        return output_path
+    except Exception:
+        return None
diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py
index 17c3b1021..ee75add01 100644
--- a/codeflash/languages/javascript/support.py
+++ b/codeflash/languages/javascript/support.py
@@ -1585,8 +1585,9 @@ def instrument_for_behavior(
     ) -> str:
         """Add behavior instrumentation to capture inputs/outputs.
 
-        For JavaScript, this wraps functions to capture their arguments
-        and return values.
+        For JavaScript, instrumentation is handled at runtime by the Babel tracer plugin
+        (babel-tracer-plugin.js) via trace-runner.js. This method returns the source
+        unchanged since no source-level transformation is needed.
 
         Args:
             source: Source code to instrument.
@@ -1594,21 +1595,11 @@ def instrument_for_behavior(
             output_file: Optional output file for traces.
 
         Returns:
-            Instrumented source code.
+            Source code unchanged (Babel handles instrumentation at runtime).
""" - if not functions: - return source - - from codeflash.languages.javascript.tracer import JavaScriptTracer - - # Use first function's file path if output_file not specified - if output_file is None: - file_path = functions[0].file_path - output_file = file_path.parent / ".codeflash" / "traces.db" - - tracer = JavaScriptTracer(output_file) - return tracer.instrument_source(source, functions[0].file_path, list(functions)) + # JavaScript tracing is done at runtime via Babel plugin, not source transformation + return source def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: """Add timing instrumentation to test code. diff --git a/codeflash/languages/javascript/tracer.py b/codeflash/languages/javascript/tracer.py index 2f5791ee0..7fcbe6822 100644 --- a/codeflash/languages/javascript/tracer.py +++ b/codeflash/languages/javascript/tracer.py @@ -1,35 +1,58 @@ """Function tracing instrumentation for JavaScript. -This module provides functionality to wrap JavaScript functions to capture their -inputs, outputs, and execution behavior. This is used for generating replay tests -and verifying optimization correctness. +This module provides functionality to parse JavaScript function traces and generate +replay tests. Tracing is performed via Babel AST transformation using the +babel-tracer-plugin.js and trace-runner.js in the npm package. + +The tracer uses Babel plugin for AST transformation which: +- Works with both CommonJS and ESM +- Handles async functions, arrow functions, methods correctly +- Preserves source maps and formatting + +Database Schema (matches Python tracer): +- function_calls: Main trace data (type, function, classname, filename, line_number, time_ns, args) +- metadata: Key-value metadata about the trace session """ from __future__ import annotations import json import logging +import re import sqlite3 -from typing import TYPE_CHECKING, Any +import textwrap +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize - logger = logging.getLogger(__name__) +@dataclass +class JavaScriptFunctionInfo: + """Information about a traced JavaScript function.""" + + function_name: str + file_name: str + module_path: str + class_name: Optional[str] = None + line_number: Optional[int] = None + + class JavaScriptTracer: - """Instruments JavaScript code to capture function inputs and outputs. + """Parses JavaScript function traces and generates replay tests. - Similar to Python's tracing system, this wraps functions to record: - - Input arguments - - Return values - - Exceptions thrown - - Execution time + Tracing is performed via Babel AST transformation (trace-runner.js). + This class handles: + - Parsing trace results from SQLite database + - Extracting traced function information + - Generating replay test files for Jest/Vitest """ + SCHEMA_VERSION = "1.0.0" + def __init__(self, output_db: Path) -> None: """Initialize the tracer. @@ -38,322 +61,15 @@ def __init__(self, output_db: Path) -> None: """ self.output_db = output_db - self.tracer_var = "__codeflash_tracer__" - - def instrument_source(self, source: str, file_path: Path, functions: list[FunctionToOptimize]) -> str: - """Instrument JavaScript source code with function tracing. - - Wraps specified functions to capture their inputs and outputs. - - Args: - source: Original JavaScript source code. - file_path: Path to the source file. 
- functions: List of functions to instrument. - - Returns: - Instrumented source code with tracing. - - """ - if not functions: - return source - - # Add tracer initialization at the top - tracer_init = self._generate_tracer_init() - - # Add instrumentation to each function - lines = source.splitlines(keepends=True) - - # Process functions in reverse order to preserve line numbers - for func in sorted(functions, key=lambda f: f.starting_line, reverse=True): - instrumented = self._instrument_function(func, lines, file_path) - start_idx = func.starting_line - 1 - end_idx = func.ending_line - lines = lines[:start_idx] + instrumented + lines[end_idx:] - - instrumented_source = "".join(lines) - - # Add tracer save at the end - tracer_save = self._generate_tracer_save() - - return tracer_init + "\n" + instrumented_source + "\n" + tracer_save - - def _generate_tracer_init(self) -> str: - """Generate JavaScript code for tracer initialization.""" - return f""" -// Codeflash function tracer initialization -const {self.tracer_var} = {{ - traces: [], - callId: 0, - - serialize: function(value) {{ - try {{ - // Handle special cases - if (value === undefined) return {{ __type__: 'undefined' }}; - if (value === null) return null; - if (typeof value === 'function') return {{ __type__: 'function', name: value.name }}; - if (typeof value === 'symbol') return {{ __type__: 'symbol', value: value.toString() }}; - if (value instanceof Error) return {{ - __type__: 'error', - name: value.name, - message: value.message, - stack: value.stack - }}; - if (typeof value === 'bigint') return {{ __type__: 'bigint', value: value.toString() }}; - if (value instanceof Date) return {{ __type__: 'date', value: value.toISOString() }}; - if (value instanceof RegExp) return {{ __type__: 'regexp', value: value.toString() }}; - if (value instanceof Map) return {{ - __type__: 'map', - value: Array.from(value.entries()).map(([k, v]) => [this.serialize(k), this.serialize(v)]) - }}; - if (value instanceof Set) return {{ - __type__: 'set', - value: Array.from(value).map(v => this.serialize(v)) - }}; - - // Handle circular references with a simple check - return JSON.parse(JSON.stringify(value)); - }} catch (e) {{ - return {{ __type__: 'unserializable', error: e.message }}; - }} - }}, - - wrap: function(originalFunc, funcName, filePath) {{ - const self = this; - - if (originalFunc.constructor.name === 'AsyncFunction') {{ - return async function(...args) {{ - const callId = self.callId++; - const start = process.hrtime.bigint(); - let result, error; - - try {{ - result = await originalFunc.apply(this, args); - }} catch (e) {{ - error = e; - }} - - const end = process.hrtime.bigint(); - - self.traces.push({{ - call_id: callId, - function: funcName, - file: filePath, - args: args.map(a => self.serialize(a)), - result: error ? null : self.serialize(result), - error: error ? self.serialize(error) : null, - runtime_ns: (end - start).toString(), - timestamp: Date.now() - }}); - - if (error) throw error; - return result; - }}; - }} - - return function(...args) {{ - const callId = self.callId++; - const start = process.hrtime.bigint(); - let result, error; - - try {{ - result = originalFunc.apply(this, args); - }} catch (e) {{ - error = e; - }} - - const end = process.hrtime.bigint(); - - self.traces.push({{ - call_id: callId, - function: funcName, - file: filePath, - args: args.map(a => self.serialize(a)), - result: error ? null : self.serialize(result), - error: error ? 
self.serialize(error) : null, - runtime_ns: (end - start).toString(), - timestamp: Date.now() - }}); - - if (error) throw error; - return result; - }}; - }}, - - saveToDb: function() {{ - const sqlite3 = require('sqlite3').verbose(); - const fs = require('fs'); - const path = require('path'); - - const dbPath = '{self.output_db.as_posix()}'; - const dbDir = path.dirname(dbPath); - - if (!fs.existsSync(dbDir)) {{ - fs.mkdirSync(dbDir, {{ recursive: true }}); - }} - - const db = new sqlite3.Database(dbPath); - - db.serialize(() => {{ - // Create table - db.run(` - CREATE TABLE IF NOT EXISTS traces ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - call_id INTEGER, - function TEXT, - file TEXT, - args TEXT, - result TEXT, - error TEXT, - runtime_ns TEXT, - timestamp INTEGER - ) - `); - - // Insert traces - const stmt = db.prepare(` - INSERT INTO traces (call_id, function, file, args, result, error, runtime_ns, timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - `); - - for (const trace of this.traces) {{ - stmt.run( - trace.call_id, - trace.function, - trace.file, - JSON.stringify(trace.args), - JSON.stringify(trace.result), - JSON.stringify(trace.error), - trace.runtime_ns, - trace.timestamp - ); - }} - - stmt.finalize(); - }}); - - db.close(); - }}, - - saveToJson: function() {{ - const fs = require('fs'); - const path = require('path'); - - const jsonPath = '{self.output_db.with_suffix(".json").as_posix()}'; - const jsonDir = path.dirname(jsonPath); - - if (!fs.existsSync(jsonDir)) {{ - fs.mkdirSync(jsonDir, {{ recursive: true }}); - }} - - fs.writeFileSync(jsonPath, JSON.stringify(this.traces, null, 2)); - }} -}}; -""" - - def _generate_tracer_save(self) -> str: - """Generate JavaScript code to save tracer results.""" - return f""" -// Save tracer results on process exit -process.on('exit', () => {{ - try {{ - {self.tracer_var}.saveToJson(); - // Try SQLite, but don't fail if sqlite3 is not installed - try {{ - {self.tracer_var}.saveToDb(); - }} catch (e) {{ - // SQLite not available, JSON is sufficient - }} - }} catch (e) {{ - console.error('Failed to save traces:', e); - }} -}}); -""" - - def _instrument_function(self, func: FunctionToOptimize, lines: list[str], file_path: Path) -> list[str]: - """Instrument a single function with tracing. - - Args: - func: Function to instrument. - lines: Source lines. - file_path: Path to source file. - - Returns: - Instrumented function lines. - - """ - func_lines = lines[func.starting_line - 1 : func.ending_line] - func_text = "".join(func_lines) - - # Detect function pattern - func_name = func.function_name - is_arrow = "=>" in func_text.split("\n")[0] - is_method = func.is_method - is_async = func.is_async - - # Generate wrapper code based on function type - if is_arrow: - # For arrow functions: const foo = (a, b) => { ... } - # Replace with: const foo = __codeflash_tracer__.wrap((a, b) => { ... }, 'foo', 'file.js') - return self._wrap_arrow_function(func_lines, func_name, file_path) - if is_method: - # For methods: methodName(a, b) { ... } - # Wrap the method body - return self._wrap_method(func_lines, func_name, file_path, is_async) - # For regular functions: function foo(a, b) { ... 
} - # Wrap the entire function - return self._wrap_regular_function(func_lines, func_name, file_path, is_async) - - def _wrap_arrow_function(self, func_lines: list[str], func_name: str, file_path: Path) -> list[str]: - """Wrap an arrow function with tracing.""" - # Find the assignment line - first_line = func_lines[0] - indent = len(first_line) - len(first_line.lstrip()) - indent_str = " " * indent - - # Insert wrapper call - func_text = "".join(func_lines).rstrip() - - # Find the '=' and wrap everything after it - if "=" in func_text: - parts = func_text.split("=", 1) - wrapped = f"{parts[0]}= {self.tracer_var}.wrap({parts[1]}, '{func_name}', '{file_path.as_posix()}');\n" - return [wrapped] - - return func_lines - - def _wrap_method(self, func_lines: list[str], func_name: str, file_path: Path, is_async: bool) -> list[str]: - """Wrap a class method with tracing.""" - # For methods, we wrap by reassigning them after definition - # This is complex, so for now we'll return unwrapped - # TODO: Implement method wrapping - logger.warning("Method wrapping not fully implemented for %s", func_name) - return func_lines - - def _wrap_regular_function( - self, func_lines: list[str], func_name: str, file_path: Path, is_async: bool - ) -> list[str]: - """Wrap a regular function declaration with tracing.""" - # Replace: function foo(a, b) { ... } - # With: const __original_foo = function foo(a, b) { ... }; const foo = __codeflash_tracer__.wrap(__original_foo, 'foo', 'file.js'); - - func_text = "".join(func_lines).rstrip() - first_line = func_lines[0] - indent = len(first_line) - len(first_line.lstrip()) - indent_str = " " * indent - - wrapped = ( - f"{indent_str}const __original_{func_name}__ = {func_text};\n" - f"{indent_str}const {func_name} = {self.tracer_var}.wrap(__original_{func_name}__, '{func_name}', '{file_path.as_posix()}');\n" - ) - - return [wrapped] @staticmethod def parse_results(trace_file: Path) -> list[dict[str, Any]]: """Parse tracing results from output file. + Supports both the new function_calls schema and legacy traces schema. + Args: - trace_file: Path to traces JSON file. + trace_file: Path to traces file (SQLite or JSON). Returns: List of trace records. 
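+
+        Example (illustrative)::
+
+            records = JavaScriptTracer.parse_results(Path("codeflash.trace.sqlite"))
+            calls = [r for r in records if r.get("type") == "call"]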
@@ -364,35 +80,59 @@ def parse_results(trace_file: Path) -> list[dict[str, Any]]: if json_file.exists(): try: with json_file.open("r") as f: - return json.load(f) + data: list[dict[str, Any]] = json.load(f) + return data except Exception as e: logger.exception("Failed to parse trace JSON: %s", e) return [] - # Try SQLite database if not trace_file.exists(): return [] try: conn = sqlite3.connect(trace_file) cursor = conn.cursor() - cursor.execute("SELECT * FROM traces ORDER BY id") + + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} traces = [] - for row in cursor.fetchall(): - traces.append( - { - "id": row[0], - "call_id": row[1], - "function": row[2], - "file": row[3], - "args": json.loads(row[4]), - "result": json.loads(row[5]), - "error": json.loads(row[6]) if row[6] != "null" else None, - "runtime_ns": int(row[7]), - "timestamp": row[8], - } + + if "function_calls" in tables: + cursor.execute( + "SELECT type, function, classname, filename, line_number, " + "last_frame_address, time_ns, args FROM function_calls ORDER BY time_ns" ) + for row in cursor.fetchall(): + traces.append( + { + "type": row[0], + "function": row[1], + "classname": row[2], + "filename": row[3], + "line_number": row[4], + "last_frame_address": row[5], + "time_ns": row[6], + "args": json.loads(row[7]) if row[7] else [], + } + ) + elif "traces" in tables: + # Legacy schema + cursor.execute("SELECT * FROM traces ORDER BY id") + for row in cursor.fetchall(): + traces.append( + { + "id": row[0], + "call_id": row[1], + "function": row[2], + "file": row[3], + "args": json.loads(row[4]) if row[4] else [], + "result": json.loads(row[5]) if row[5] else None, + "error": json.loads(row[6]) if row[6] and row[6] != "null" else None, + "runtime_ns": int(row[7]) if row[7] else 0, + "timestamp": row[8] if len(row) > 8 else None, + } + ) conn.close() return traces @@ -400,3 +140,168 @@ def parse_results(trace_file: Path) -> list[dict[str, Any]]: except Exception as e: logger.exception("Failed to parse trace database: %s", e) return [] + + @staticmethod + def get_traced_functions(trace_file: Path) -> list[JavaScriptFunctionInfo]: + """Get list of functions that were traced. + + Args: + trace_file: Path to trace database. + + Returns: + List of traced function information. 
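+
+        Example (illustrative)::
+
+            for fn in JavaScriptTracer.get_traced_functions(Path("codeflash.trace.sqlite")):
+                print(fn.module_path, fn.function_name)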
+
+        """
+        if not trace_file.exists():
+            return []
+
+        try:
+            conn = sqlite3.connect(trace_file)
+            cursor = conn.cursor()
+
+            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+            tables = {row[0] for row in cursor.fetchall()}
+
+            functions = []
+
+            if "function_calls" in tables:
+                cursor.execute(
+                    "SELECT DISTINCT function, filename, classname, line_number FROM function_calls WHERE type = 'call'"
+                )
+                for row in cursor.fetchall():
+                    func_name = row[0]
+                    file_name = row[1]
+                    class_name = row[2]
+                    line_number = row[3]
+
+                    # Strip only a trailing extension so ".js"/".ts"
+                    # appearing elsewhere in the path is not mangled
+                    module_path = file_name.replace("\\", "/")
+                    for ext in (".jsx", ".tsx", ".js", ".ts"):
+                        if module_path.endswith(ext):
+                            module_path = module_path[: -len(ext)]
+                            break
+                    module_path = module_path.removeprefix("./")
+
+                    functions.append(
+                        JavaScriptFunctionInfo(
+                            function_name=func_name,
+                            file_name=file_name,
+                            module_path=module_path,
+                            class_name=class_name,
+                            line_number=line_number,
+                        )
+                    )
+
+            conn.close()
+            return functions
+
+        except Exception as e:
+            logger.exception("Failed to get traced functions: %s", e)
+            return []
+
+    def create_replay_test(
+        self,
+        trace_file: Path,
+        output_path: Path,
+        framework: str = "jest",
+        max_run_count: int = 100,
+        project_root: Optional[Path] = None,
+    ) -> Optional[str]:
+        """Generate a replay test file from traced function calls.
+
+        Args:
+            trace_file: Path to the trace database.
+            output_path: Path to write the test file.
+            framework: Test framework ('jest' or 'vitest').
+            max_run_count: Maximum number of test cases per function.
+            project_root: Project root for calculating relative imports.
+
+        Returns:
+            Path to generated test file, or None if generation failed.
+
+        """
+        functions = self.get_traced_functions(trace_file)
+        if not functions:
+            logger.warning("No traced functions found in %s", trace_file)
+            return None
+
+        is_vitest = framework.lower() == "vitest"
+
+        imports = []
+        if is_vitest:
+            imports.append("import { describe, test } from 'vitest';")
+
+        imports.append("const { getNextArg } = require('codeflash/replay');")
+        imports.append("")
+
+        for func in functions:
+            alias = self._get_function_alias(func.module_path, func.function_name, func.class_name)
+            if func.class_name:
+                imports.append(f"const {{ {func.class_name}: {alias}_class }} = require('./{func.module_path}');")
+            else:
+                imports.append(f"const {{ {func.function_name}: {alias} }} = require('./{func.module_path}');")
+
+        imports.append("")
+
+        trace_path = trace_file.as_posix()
+        metadata = [
+            f"const traceFilePath = '{trace_path}';",
+            f"const functions = {json.dumps([f.function_name for f in functions])};",
+            "",
+        ]
+
+        test_cases = []
+        for func in functions:
+            alias = self._get_function_alias(func.module_path, func.function_name, func.class_name)
+            test_name = f"{func.class_name}.{func.function_name}" if func.class_name else func.function_name
+            class_arg = f"'{func.class_name}'" if func.class_name else "null"
+
+            if func.class_name:
+                test_cases.append(
+                    textwrap.dedent(f"""
+describe('Replay: {test_name}', () => {{
+    const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name}', {max_run_count}, {class_arg});
+
+    test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{
+        const instance = new {alias}_class();
+        instance.{func.function_name}(...args);
+    }});
+}});
+""")
+                )
+            else:
+                test_cases.append(
+                    textwrap.dedent(f"""
+describe('Replay: {test_name}', () => {{
+    const traces = getNextArg(traceFilePath, '{func.function_name}', '{func.file_name}', {max_run_count});
+
+    test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => {{
+        {alias}(...args);
+    }});
+}});
+""")
+                )
+
+        content = "\n".join(
+            [
+                "// Auto-generated replay test by Codeflash",
+                "// Do not edit this file directly",
+                "",
+                *imports,
+                *metadata,
+                *test_cases,
+            ]
+        )
+
+        try:
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            output_path.write_text(content)
+            logger.info("Generated replay test: %s", output_path)
+            return str(output_path)
+        except Exception as e:
+            logger.exception("Failed to write replay test: %s", e)
+            return None
+
+    @staticmethod
+    def _get_function_alias(module_path: str, function_name: str, class_name: Optional[str] = None) -> str:
+        """Create a function alias for imports."""
+        module_alias = re.sub(r"[^a-zA-Z0-9]", "_", module_path).strip("_")
+
+        if class_name:
+            return f"{module_alias}_{class_name}_{function_name}"
+        return f"{module_alias}_{function_name}"
diff --git a/codeflash/languages/javascript/tracer_runner.py b/codeflash/languages/javascript/tracer_runner.py
new file mode 100644
index 000000000..8a1cacb49
--- /dev/null
+++ b/codeflash/languages/javascript/tracer_runner.py
@@ -0,0 +1,339 @@
+"""JavaScript tracer runner.
+
+This module provides functionality to run JavaScript code with function tracing
+enabled. It spawns a Node.js subprocess with the trace-runner.js script and
+generates replay tests after tracing completes.
+
+The tracer supports two modes:
+1. Script mode: Trace a specific JavaScript file
+2. Test mode: Trace tests running under Jest or Vitest
+
+Usage from CLI:
+    codeflash trace --language javascript script.js
+    codeflash trace --language javascript --jest --testPathPattern=mytest
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+import shutil
+import subprocess
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+    from argparse import Namespace
+
+logger = logging.getLogger(__name__)
+
+
+def find_node_executable() -> Optional[Path]:
+    """Find the Node.js executable.
+
+    Returns:
+        Path to node executable, or None if not found.
+
+    """
+    node_path = shutil.which("node")
+    if node_path:
+        return Path(node_path)
+
+    # Note: npx is deliberately not used as a fallback -- it resolves package
+    # binaries rather than executing a script file the way `node` does.
+    return None
+
+
+def find_trace_runner() -> Optional[Path]:
+    """Find the trace-runner.js script.
+
+    Returns:
+        Path to trace-runner.js, or None if not found.
+
+    """
+    # First, try to find it in the installed codeflash npm package
+    # Check common node_modules locations
+    cwd = Path.cwd()
+
+    # Check project-local node_modules
+    local_path = cwd / "node_modules" / "codeflash" / "runtime" / "trace-runner.js"
+    if local_path.exists():
+        return local_path
+
+    # Check global npm packages
+    try:
+        result = subprocess.run(["npm", "root", "-g"], capture_output=True, text=True, check=True)
+        global_modules = Path(result.stdout.strip())
+        global_path = global_modules / "codeflash" / "runtime" / "trace-runner.js"
+        if global_path.exists():
+            return global_path
+    except Exception:
+        pass
+
+    # Fall back to the bundled version in the Python package
+    bundled_path = Path(__file__).parent.parent.parent.parent / "packages" / "codeflash" / "runtime" / "trace-runner.js"
+    if bundled_path.exists():
+        return bundled_path
+
+    return None
+
+
+def run_javascript_tracer(args: Namespace, config: dict[str, Any], project_root: Path) -> dict[str, Any]:
+    """Run JavaScript code with function tracing enabled.
+
+    Args:
+        args: Command line arguments.
+        config: Project configuration.
+        project_root: Project root directory.
+ + Returns: + Dictionary with tracing results: + - success: Whether tracing succeeded + - trace_file: Path to trace database + - replay_test_file: Path to generated replay test (if any) + - error: Error message (if failed) + + """ + result: dict[str, Any] = {"success": False, "trace_file": None, "replay_test_file": None, "error": None} + + # Find Node.js + node_path = find_node_executable() + if not node_path: + result["error"] = "Node.js not found. Please install Node.js to use JavaScript tracing." + logger.error(result["error"]) + return result + + # Find trace runner + trace_runner_path = find_trace_runner() + if not trace_runner_path: + result["error"] = "trace-runner.js not found. Please install the codeflash npm package." + logger.error(result["error"]) + return result + + # Determine output paths + outfile = getattr(args, "outfile", None) or "codeflash.trace.sqlite" + trace_file = Path(outfile).resolve() + + # Build environment + env = os.environ.copy() + env["CODEFLASH_TRACE_DB"] = str(trace_file) + env["CODEFLASH_PROJECT_ROOT"] = str(project_root) + + # Set max function count + max_count = getattr(args, "max_function_count", 256) + env["CODEFLASH_MAX_FUNCTION_COUNT"] = str(max_count) + + # Set timeout if specified + timeout = getattr(args, "tracer_timeout", None) + if timeout: + env["CODEFLASH_TRACER_TIMEOUT"] = str(timeout) + + # Set functions to trace if specified + only_functions = getattr(args, "only_functions", None) + if only_functions: + env["CODEFLASH_FUNCTIONS"] = json.dumps(only_functions) + + # Build command + cmd = [str(node_path), str(trace_runner_path)] + + # Add trace runner options + cmd.extend(["--trace-db", str(trace_file)]) + cmd.extend(["--project-root", str(project_root)]) + + if max_count: + cmd.extend(["--max-function-count", str(max_count)]) + + if timeout: + cmd.extend(["--timeout", str(timeout)]) + + if only_functions: + cmd.extend(["--functions", json.dumps(only_functions)]) + + # Determine mode and add appropriate flags + is_module = getattr(args, "module", False) + script_args = [] + + # Get the remaining arguments after parsing + if hasattr(args, "script_args"): + script_args = args.script_args + elif hasattr(args, "unknown_args"): + script_args = args.unknown_args + + if is_module and script_args and script_args[0] == "jest": + cmd.append("--jest") + cmd.append("--") + cmd.extend(script_args[1:]) + elif is_module and script_args and script_args[0] == "vitest": + cmd.append("--vitest") + cmd.append("--") + cmd.extend(script_args[1:]) + elif script_args: + # Regular script mode + cmd.extend(script_args) + + # Run the tracer + logger.info("Running JavaScript tracer: %s", " ".join(cmd)) + + try: + process = subprocess.run(cmd, cwd=project_root, env=env, capture_output=False, check=False) + + if process.returncode != 0: + result["error"] = f"Tracing failed with exit code {process.returncode}" + logger.error(result["error"]) + return result + + except Exception as e: + result["error"] = f"Failed to run tracer: {e}" + logger.exception(result["error"]) + return result + + # Check if trace file was created + if not trace_file.exists(): + result["error"] = f"Trace file not created: {trace_file}" + logger.error(result["error"]) + return result + + result["success"] = True + result["trace_file"] = str(trace_file) + + # Generate replay test if not in trace-only mode + trace_only = getattr(args, "trace_only", False) + if not trace_only: + replay_test_path = generate_replay_test(trace_file=trace_file, project_root=project_root, config=config) + if 
replay_test_path: + result["replay_test_file"] = str(replay_test_path) + logger.info("Generated replay test: %s", replay_test_path) + + return result + + +def generate_replay_test( + trace_file: Path, project_root: Path, config: dict[str, Any], output_path: Optional[Path] = None +) -> Optional[Path]: + """Generate a replay test file from trace data. + + Args: + trace_file: Path to trace SQLite database. + project_root: Project root directory. + config: Project configuration. + output_path: Optional custom output path. + + Returns: + Path to generated test file, or None if generation failed. + + """ + from codeflash.languages.javascript.replay_test import create_replay_test_file + + # Determine test framework from config or detect from project + framework = detect_test_framework(project_root, config) + + # Determine output path + if output_path is None: + tests_root = config.get("tests_root", "tests") + tests_dir = project_root / tests_root + output_path = tests_dir / "codeflash_replay.test.js" + + return create_replay_test_file( + trace_file=trace_file, + output_path=output_path, + framework=framework, + max_run_count=100, + project_root=project_root, + ) + + +def detect_test_framework(project_root: Path, config: dict[str, Any]) -> str: + """Detect the test framework used by the project. + + Args: + project_root: Project root directory. + config: Project configuration. + + Returns: + Test framework name ('jest' or 'vitest'). + + """ + # Check config first + if "test_framework" in config: + framework: str = config["test_framework"] + return framework + + # Check for vitest config files + vitest_configs = ["vitest.config.js", "vitest.config.ts", "vitest.config.mjs"] + for conf in vitest_configs: + if (project_root / conf).exists(): + return "vitest" + + # Check for jest config files + jest_configs = ["jest.config.js", "jest.config.ts", "jest.config.mjs", "jest.config.json"] + for conf in jest_configs: + if (project_root / conf).exists(): + return "jest" + + # Check package.json for test script + package_json = project_root / "package.json" + if package_json.exists(): + try: + with package_json.open() as f: + pkg = json.load(f) + test_script = pkg.get("scripts", {}).get("test", "") + if "vitest" in test_script: + return "vitest" + if "jest" in test_script: + return "jest" + + # Check dependencies + deps = {**pkg.get("dependencies", {}), **pkg.get("devDependencies", {})} + if "vitest" in deps: + return "vitest" + if "jest" in deps: + return "jest" + except Exception: + pass + + # Default to Jest + return "jest" + + +def check_javascript_tracer_available() -> bool: + """Check if JavaScript tracing is available. + + Returns: + True if all requirements are met for JavaScript tracing. + + """ + # Check for Node.js + if not find_node_executable(): + return False + + # Check for trace runner + if not find_trace_runner(): + return False + + return True + + +def get_tracer_requirements_message() -> str: + """Get a message about tracer requirements. + + Returns: + Human-readable message about what's needed for JavaScript tracing. + + """ + missing = [] + + if not find_node_executable(): + missing.append("Node.js (v18+)") + + if not find_trace_runner(): + missing.append("codeflash npm package (npm install codeflash)") + + if not missing: + return "All requirements met for JavaScript tracing." 
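+
+    # Note: find_node_executable() locates node but does not verify its
+    # version; the "v18+" hint below is informational only.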
+ + return "Missing requirements for JavaScript tracing:\n- " + "\n- ".join(missing) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index fad0b795d..2d73c4cb7 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -18,7 +18,7 @@ import sys from argparse import ArgumentParser from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from codeflash.cli_cmds.cli import project_root_from_module_root from codeflash.cli_cmds.console import console @@ -32,6 +32,41 @@ from argparse import Namespace +def detect_language_from_config(config: dict[str, Any]) -> str: + """Detect the project language from config or file extensions. + + Args: + config: Project configuration dictionary. + + Returns: + Language identifier ('python', 'javascript', or 'typescript'). + + """ + # Check explicit language in config + if "language" in config: + language: str = config["language"].lower() + return language + + # Check module root for file types + module_root = Path(config.get("module_root", ".")) + if module_root.exists(): + js_files = list(module_root.glob("**/*.js")) + list(module_root.glob("**/*.jsx")) + ts_files = list(module_root.glob("**/*.ts")) + list(module_root.glob("**/*.tsx")) + py_files = list(module_root.glob("**/*.py")) + + # Filter out node_modules + js_files = [f for f in js_files if "node_modules" not in str(f)] + ts_files = [f for f in ts_files if "node_modules" not in str(f)] + + total_js = len(js_files) + len(ts_files) + total_py = len(py_files) + + if total_js > total_py: + return "typescript" if len(ts_files) > len(js_files) else "javascript" + + return "python" + + def main(args: Namespace | None = None) -> ArgumentParser: parser = ArgumentParser(allow_abbrev=False) parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", default="codeflash.trace") @@ -60,6 +95,11 @@ def main(args: Namespace | None = None) -> ArgumentParser: parser.add_argument( "--limit", type=int, default=None, help="Limit the number of test files to process (for -m pytest mode)" ) + parser.add_argument( + "--language", + help="Language to trace (python, javascript, typescript). 
Auto-detected if not specified.", + default=None, + ) if args is not None: parsed_args = args @@ -93,6 +133,16 @@ def main(args: Namespace | None = None) -> ArgumentParser: outfile = parsed_args.outfile config, found_config_path = parse_config_file(parsed_args.codeflash_config) project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path) + + # Detect or use specified language + language = getattr(parsed_args, "language", None) or detect_language_from_config(config) + + # Route to appropriate tracer based on language + if language in ("javascript", "typescript"): + if outfile is None: + outfile = Path("codeflash.trace.sqlite") + return run_javascript_tracer_main(parsed_args, config, project_root, outfile, unknown_args) + if len(unknown_args) > 0: args_dict = { "functions": parsed_args.only_functions, @@ -105,16 +155,17 @@ def main(args: Namespace | None = None) -> ArgumentParser: "module": parsed_args.module, } try: - pytest_splits = [] - test_paths = [] - replay_test_paths = [] + pytest_splits: list[list[str]] = [] + test_paths: list[str] = [] + replay_test_paths: list[str] = [] if parsed_args.module and unknown_args[0] == "pytest": - pytest_splits, test_paths = pytest_split(unknown_args[1:], limit=parsed_args.limit) - if pytest_splits is None or test_paths is None: + split_result = pytest_split(unknown_args[1:], limit=parsed_args.limit) + if split_result[0] is None or split_result[1] is None: console.print(f"❌ Could not find test files in the specified paths: {unknown_args[1:]}") console.print(f"Current working directory: {Path.cwd()}") console.print("Please ensure the test directory exists and contains test files.") sys.exit(1) + pytest_splits, test_paths = split_result if len(pytest_splits) > 1: processes = [] @@ -255,5 +306,89 @@ def main(args: Namespace | None = None) -> ArgumentParser: return parser +def run_javascript_tracer_main( + parsed_args: Namespace, config: dict[str, Any], project_root: Path, outfile: Path, unknown_args: list[str] +) -> ArgumentParser: + """Run the JavaScript tracer. + + Args: + parsed_args: Parsed command line arguments. + config: Project configuration. + project_root: Project root directory. + outfile: Output trace file path. + unknown_args: Remaining command line arguments. + + Returns: + The argument parser. 
+ + """ + from codeflash.languages.javascript.tracer_runner import ( + check_javascript_tracer_available, + get_tracer_requirements_message, + run_javascript_tracer, + ) + + # Check requirements + if not check_javascript_tracer_available(): + console.print(f"[red]{get_tracer_requirements_message()}[/red]") + sys.exit(1) + + # Prepare args for the tracer runner + parsed_args.script_args = unknown_args + + # Run the tracer + console.print("[bold blue]Running JavaScript tracer...[/bold blue]") + result = run_javascript_tracer(parsed_args, config, project_root) + + if not result["success"]: + console.print(f"[red]Tracing failed: {result.get('error', 'Unknown error')}[/red]") + sys.exit(1) + + console.print(f"[green]Trace saved to: {result['trace_file']}[/green]") + + if result.get("replay_test_file"): + console.print(f"[green]Replay test generated: {result['replay_test_file']}[/green]") + + # Run optimization if not trace-only mode + if not parsed_args.trace_only: + from codeflash.cli_cmds.cli import parse_args as cli_parse_args + from codeflash.cli_cmds.cli import process_pyproject_config + from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO + from codeflash.cli_cmds.console import paneled_text + from codeflash.languages import set_current_language + from codeflash.languages.base import Language + from codeflash.telemetry import posthog_cf + from codeflash.telemetry.sentry import init_sentry + + # Set language to JavaScript + set_current_language(Language.JAVASCRIPT) + + sys.argv = ["codeflash", "--replay-test", result["replay_test_file"]] + args = cli_parse_args() + paneled_text( + CODEFLASH_LOGO, + panel_args={"title": "https://codeflash.ai", "expand": False}, + text_args={"style": "bold gold3"}, + ) + + args = process_pyproject_config(args) + args.previous_checkpoint_functions = None + init_sentry(enabled=not args.disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(enabled=not args.disable_telemetry) + + from codeflash.optimization import optimizer + + args.effort = EffortLevel.HIGH.value + optimizer.run_with_args(args) + + # Clean up trace and replay test files + if outfile: + outfile.unlink(missing_ok=True) + Path(result["replay_test_file"]).unlink(missing_ok=True) + + # Return a new parser for API compatibility + return ArgumentParser(allow_abbrev=False) + + if __name__ == "__main__": main() diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index c80a287e5..5836adebc 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -32,10 +32,6 @@ ) from codeflash.verification.coverage_utils import CoverageUtils, JestCoverageUtils -# Import Jest-specific parsing from the JavaScript language module -from codeflash.languages.javascript.parse import jest_end_pattern, jest_start_pattern -from codeflash.languages.javascript.parse import parse_jest_test_xml as _parse_jest_test_xml - if TYPE_CHECKING: import subprocess @@ -56,8 +52,11 @@ def parse_func(file_path: Path) -> XMLParser: start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!") end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!") -# Jest timing marker patterns are imported from codeflash.languages.javascript.parse -# and re-exported here for backwards compatibility +# Jest timing marker patterns (from codeflash-jest-helper.js console.log output) +# Format: !$######testName:testName:funcName:loopIndex:lineId######$! 
(start)
+# Format: !######testName:testName:funcName:loopIndex:lineId:durationNs######! (end)
+jest_start_pattern = re.compile(r"!\$######([^:]+):([^:]+):([^:]+):([^:]+):([^#]+)######\$!")
+jest_end_pattern = re.compile(r"!######([^:]+):([^:]+):([^:]+):([^:]+):([^:]+):(\d+)######!")
 def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
@@ -557,6 +556,430 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
     return test_results
+
+
+def _extract_jest_console_output(suite_elem) -> str:
+    """Extract console output from Jest's JUnit XML system-out element.
+
+    Jest-junit writes console.log output as a JSON array in the testsuite's system-out.
+    Each entry has: {"message": "...", "origin": "...", "type": "log"}
+
+    Args:
+        suite_elem: The testsuite lxml element
+
+    Returns:
+        Concatenated message content from all log entries
+
+    """
+    import json
+
+    system_out_elem = suite_elem.find("system-out")
+    if system_out_elem is None or system_out_elem.text is None:
+        return ""
+
+    raw_content = system_out_elem.text.strip()
+    if not raw_content:
+        return ""
+
+    # Jest-junit wraps console output in a JSON array
+    # Try to parse as JSON first
+    try:
+        log_entries = json.loads(raw_content)
+        if isinstance(log_entries, list):
+            # Extract message field from each log entry
+            messages = []
+            for entry in log_entries:
+                if isinstance(entry, dict) and "message" in entry:
+                    messages.append(entry["message"])
+            return "\n".join(messages)
+    except (json.JSONDecodeError, TypeError):
+        # Not JSON - return as plain text (fallback for pytest-style output)
+        pass
+
+    return raw_content
+
+
+# TODO: move this to the support directory.
+def parse_jest_test_xml(
+    test_xml_file_path: Path,
+    test_files: TestFiles,
+    test_config: TestConfig,
+    run_result: subprocess.CompletedProcess | None = None,
+) -> TestResults:
+    """Parse Jest JUnit XML test results.
+
+    Jest-junit has a different structure than pytest:
+    - system-out is at the testsuite level (not testcase)
+    - system-out contains a JSON array of log entries
+    - Timing markers are in the message field of log entries
+
+    Args:
+        test_xml_file_path: Path to the Jest JUnit XML file
+        test_files: TestFiles object with test file information
+        test_config: Test configuration
+        run_result: Optional subprocess result for logging
+
+    Returns:
+        TestResults containing parsed test invocations
+
+    """
+    test_results = TestResults()
+
+    if not test_xml_file_path.exists():
+        logger.warning(f"No JavaScript test results found for {test_xml_file_path}.")
+        return test_results
+
+    # Log file size for debugging
+    file_size = test_xml_file_path.stat().st_size
+    logger.debug(f"Jest XML file size: {file_size} bytes at {test_xml_file_path}")
+
+    try:
+        xml = JUnitXml.fromfile(str(test_xml_file_path), parse_func=parse_func)
+        logger.debug(f"Successfully parsed Jest JUnit XML from {test_xml_file_path}")
+    except Exception as e:
+        logger.warning(f"Failed to parse {test_xml_file_path} as JUnitXml. 
Exception: {e}") + return test_results + + base_dir = test_config.tests_project_rootdir + logger.debug(f"Jest XML parsing: base_dir={base_dir}, num_test_files={len(test_files.test_files)}") + + # Build lookup from instrumented file path to TestFile for direct matching + # This handles cases where instrumented files are in temp directories + instrumented_path_lookup: dict[str, tuple[Path, TestType]] = {} + for test_file in test_files.test_files: + if test_file.instrumented_behavior_file_path: + # Store both the absolute path and resolved path as keys + abs_path = str(test_file.instrumented_behavior_file_path.resolve()) + instrumented_path_lookup[abs_path] = (test_file.instrumented_behavior_file_path, test_file.test_type) + # Also store the string representation in case of minor path differences + instrumented_path_lookup[str(test_file.instrumented_behavior_file_path)] = ( + test_file.instrumented_behavior_file_path, + test_file.test_type, + ) + logger.debug(f"Jest XML lookup: registered {abs_path}") + # Also add benchmarking file paths (perf-only instrumented tests) + if test_file.benchmarking_file_path: + bench_abs_path = str(test_file.benchmarking_file_path.resolve()) + instrumented_path_lookup[bench_abs_path] = (test_file.benchmarking_file_path, test_file.test_type) + instrumented_path_lookup[str(test_file.benchmarking_file_path)] = ( + test_file.benchmarking_file_path, + test_file.test_type, + ) + logger.debug(f"Jest XML lookup: registered benchmark {bench_abs_path}") + + # Also build a filename-only lookup for fallback matching + # This handles cases where JUnit XML has relative paths that don't match absolute paths + # e.g., JUnit has "test/utils__perfinstrumented.test.ts" but lookup has absolute paths + filename_lookup: dict[str, tuple[Path, TestType]] = {} + for test_file in test_files.test_files: + # Add instrumented_behavior_file_path (behavior tests) + if test_file.instrumented_behavior_file_path: + filename = test_file.instrumented_behavior_file_path.name + # Only add if not already present (avoid overwrites in case of duplicate filenames) + if filename not in filename_lookup: + filename_lookup[filename] = (test_file.instrumented_behavior_file_path, test_file.test_type) + logger.debug(f"Jest XML filename lookup: registered {filename}") + # Also add benchmarking_file_path (perf-only tests) - these have different filenames + # e.g., utils__perfonlyinstrumented.test.ts vs utils__perfinstrumented.test.ts + if test_file.benchmarking_file_path: + bench_filename = test_file.benchmarking_file_path.name + if bench_filename not in filename_lookup: + filename_lookup[bench_filename] = (test_file.benchmarking_file_path, test_file.test_type) + logger.debug(f"Jest XML filename lookup: registered benchmark file {bench_filename}") + + # Fallback: if JUnit XML doesn't have system-out, use subprocess stdout directly + global_stdout = "" + if run_result is not None: + try: + global_stdout = run_result.stdout if isinstance(run_result.stdout, str) else run_result.stdout.decode() + # Debug: log if timing markers are found in stdout + if global_stdout: + marker_count = len(jest_start_pattern.findall(global_stdout)) + if marker_count > 0: + logger.debug(f"Found {marker_count} timing start markers in Jest stdout") + else: + logger.debug(f"No timing start markers found in Jest stdout (len={len(global_stdout)})") + except (AttributeError, UnicodeDecodeError): + global_stdout = "" + + suite_count = 0 + testcase_count = 0 + for suite in xml: + suite_count += 1 + # Extract console output from suite-level 
system-out (Jest specific) + suite_stdout = _extract_jest_console_output(suite._elem) # noqa: SLF001 + + # Fallback: use subprocess stdout if XML system-out is empty + if not suite_stdout and global_stdout: + suite_stdout = global_stdout + + # Parse timing markers from the suite's console output + start_matches = list(jest_start_pattern.finditer(suite_stdout)) + end_matches_dict = {} + for match in jest_end_pattern.finditer(suite_stdout): + # Key: (testName, testName2, funcName, loopIndex, lineId) + key = match.groups()[:5] + end_matches_dict[key] = match + + # Also collect timing markers from testcase-level system-out (Vitest puts output at testcase level) + for tc in suite: + tc_system_out = tc._elem.find("system-out") # noqa: SLF001 + if tc_system_out is not None and tc_system_out.text: + tc_stdout = tc_system_out.text.strip() + logger.debug(f"Vitest testcase system-out found: {len(tc_stdout)} chars, first 200: {tc_stdout[:200]}") + end_marker_count = 0 + for match in jest_end_pattern.finditer(tc_stdout): + key = match.groups()[:5] + end_matches_dict[key] = match + end_marker_count += 1 + if end_marker_count > 0: + logger.debug(f"Found {end_marker_count} END timing markers in testcase system-out") + start_matches.extend(jest_start_pattern.finditer(tc_stdout)) + + for testcase in suite: + testcase_count += 1 + test_class_path = testcase.classname # For Jest, this is the file path + test_name = testcase.name + + if test_name is None: + logger.debug(f"testcase.name is None in Jest XML {test_xml_file_path}, skipping") + continue + + logger.debug(f"Jest XML: processing testcase name={test_name}, classname={test_class_path}") + + # First, try direct lookup in instrumented file paths + # This handles cases where instrumented files are in temp directories + test_file_path = None + test_type = None + + if test_class_path: + # Try exact match with classname (which should be the filepath from jest-junit) + if test_class_path in instrumented_path_lookup: + test_file_path, test_type = instrumented_path_lookup[test_class_path] + else: + # Try resolving the path and matching + try: + resolved_path = str(Path(test_class_path).resolve()) + if resolved_path in instrumented_path_lookup: + test_file_path, test_type = instrumented_path_lookup[resolved_path] + except Exception: + pass + + # If direct lookup failed, try the file attribute + if test_file_path is None: + test_file_name = suite._elem.attrib.get("file") or testcase._elem.attrib.get("file") # noqa: SLF001 + if test_file_name: + if test_file_name in instrumented_path_lookup: + test_file_path, test_type = instrumented_path_lookup[test_file_name] + else: + try: + resolved_path = str(Path(test_file_name).resolve()) + if resolved_path in instrumented_path_lookup: + test_file_path, test_type = instrumented_path_lookup[resolved_path] + except Exception: + pass + + # Fall back to traditional path resolution if direct lookup failed + if test_file_path is None: + test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir) + if test_file_path is None: + test_file_name = suite._elem.attrib.get("file") or testcase._elem.attrib.get("file") # noqa: SLF001 + if test_file_name: + test_file_path = base_dir.parent / test_file_name + if not test_file_path.exists(): + test_file_path = base_dir / test_file_name + + # Fallback: try matching by filename only + # This handles when JUnit XML has relative paths like "test/utils__perfinstrumented.test.ts" + # that can't be resolved to absolute paths because they're relative to Jest's CWD, not parse CWD + if 
test_file_path is None and test_class_path: + # Extract filename from the path (handles both forward and back slashes) + path_filename = Path(test_class_path).name + if path_filename in filename_lookup: + test_file_path, test_type = filename_lookup[path_filename] + logger.debug(f"Jest XML: matched by filename {path_filename}") + + # Also try filename matching on the file attribute if classname matching failed + if test_file_path is None: + test_file_name = suite._elem.attrib.get("file") or testcase._elem.attrib.get("file") # noqa: SLF001 + if test_file_name: + file_attr_filename = Path(test_file_name).name + if file_attr_filename in filename_lookup: + test_file_path, test_type = filename_lookup[file_attr_filename] + logger.debug(f"Jest XML: matched by file attr filename {file_attr_filename}") + + # For Jest tests in monorepos, test files may not exist after cleanup + # but we can still parse results and infer test type from the path + if test_file_path is None: + logger.warning(f"Could not resolve test file for Jest test: {test_class_path}") + continue + + # Get test type if not already set from lookup + if test_type is None and test_file_path.exists(): + test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) + if test_type is None: + # Infer test type from filename pattern + filename = test_file_path.name + if "__perf_test_" in filename or "_perf_test_" in filename: + test_type = TestType.GENERATED_PERFORMANCE + elif "__unit_test_" in filename or "_unit_test_" in filename: + test_type = TestType.GENERATED_REGRESSION + else: + # Default to GENERATED_REGRESSION for Jest tests + test_type = TestType.GENERATED_REGRESSION + + # For Jest tests, keep the relative file path with extension intact + # (Python uses module_name_from_file_path which strips extensions) + try: + test_module_path = str(test_file_path.relative_to(test_config.tests_project_rootdir)) + except ValueError: + test_module_path = test_file_path.name + result = testcase.is_passed + + # Check for timeout + timed_out = False + if len(testcase.result) >= 1: + message = (testcase.result[0].message or "").lower() + if "timeout" in message or "timed out" in message: + timed_out = True + + # Find matching timing markers for this test + # Jest test names in markers are sanitized by codeflash-jest-helper's sanitizeTestId() + # which replaces: !#: (space) ()[]{}|\/*?^$.+- with underscores + # IMPORTANT: Must match Jest helper's sanitization exactly for marker matching to work + # Pattern from capture.js: /[!#: ()\[\]{}|\\/*?^$.+\-]/g + sanitized_test_name = re.sub(r"[!#: ()\[\]{}|\\/*?^$.+\-]", "_", test_name) + matching_starts = [m for m in start_matches if sanitized_test_name in m.group(2)] + + # For performance tests (capturePerf), there are no START markers - only END markers with duration + # Check for END markers directly if no START markers found + matching_ends_direct = [] + if not matching_starts: + # Look for END markers that match this test (performance test format) + # END marker format: !######module:testName:funcName:loopIndex:invocationId:durationNs######! 
+ for end_key, end_match in end_matches_dict.items(): + # end_key is (module, testName, funcName, loopIndex, invocationId) + if len(end_key) >= 2 and sanitized_test_name in end_key[1]: + matching_ends_direct.append(end_match) + + if not matching_starts and not matching_ends_direct: + # No timing markers found - use JUnit XML time attribute as fallback + # The time attribute is in seconds (e.g., "0.00077875"), convert to nanoseconds + runtime = None + try: + time_attr = testcase._elem.attrib.get("time") # noqa: SLF001 + if time_attr: + time_seconds = float(time_attr) + runtime = int(time_seconds * 1_000_000_000) # Convert seconds to nanoseconds + logger.debug(f"Jest XML: using time attribute for {test_name}: {time_seconds}s = {runtime}ns") + except (ValueError, TypeError) as e: + logger.debug(f"Jest XML: could not parse time attribute: {e}") + + test_results.add( + FunctionTestInvocation( + loop_index=1, + id=InvocationId( + test_module_path=test_module_path, + test_class_name=None, + test_function_name=test_name, + function_getting_tested="", + iteration_id="", + ), + file_name=test_file_path, + runtime=runtime, + test_framework=test_config.test_framework, + did_pass=result, + test_type=test_type, + return_value=None, + timed_out=timed_out, + stdout="", + ) + ) + elif matching_ends_direct: + # Performance test format: process END markers directly (no START markers) + for end_match in matching_ends_direct: + groups = end_match.groups() + # groups: (module, testName, funcName, loopIndex, invocationId, durationNs) + func_name = groups[2] + loop_index = int(groups[3]) if groups[3].isdigit() else 1 + line_id = groups[4] + try: + runtime = int(groups[5]) + except (ValueError, IndexError): + runtime = None + test_results.add( + FunctionTestInvocation( + loop_index=loop_index, + id=InvocationId( + test_module_path=test_module_path, + test_class_name=None, + test_function_name=test_name, + function_getting_tested=func_name, + iteration_id=line_id, + ), + file_name=test_file_path, + runtime=runtime, + test_framework=test_config.test_framework, + did_pass=result, + test_type=test_type, + return_value=None, + timed_out=timed_out, + stdout="", + ) + ) + else: + # Process each timing marker + for match in matching_starts: + groups = match.groups() + # groups: (testName, testName2, funcName, loopIndex, lineId) + func_name = groups[2] + loop_index = int(groups[3]) if groups[3].isdigit() else 1 + line_id = groups[4] + + # Find matching end marker + end_key = groups[:5] + end_match = end_matches_dict.get(end_key) + + runtime = None + if end_match: + # Duration is in the 6th group (index 5) + with contextlib.suppress(ValueError, IndexError): + runtime = int(end_match.group(6)) + test_results.add( + FunctionTestInvocation( + loop_index=loop_index, + id=InvocationId( + test_module_path=test_module_path, + test_class_name=None, + test_function_name=test_name, + function_getting_tested=func_name, + iteration_id=line_id, + ), + file_name=test_file_path, + runtime=runtime, + test_framework=test_config.test_framework, + did_pass=result, + test_type=test_type, + return_value=None, + timed_out=timed_out, + stdout="", + ) + ) + + if not test_results: + logger.info( + f"No Jest test results parsed from {test_xml_file_path} " + f"(found {suite_count} suites, {testcase_count} testcases)" + ) + if run_result is not None: + logger.debug(f"Jest stdout: {run_result.stdout[:1000] if run_result.stdout else 'empty'}") + else: + logger.debug( + f"Jest XML parsing complete: {len(test_results.test_results)} results " + f"from 
{suite_count} suites, {testcase_count} testcases" + ) + + return test_results + + def parse_test_xml( test_xml_file_path: Path, test_files: TestFiles, @@ -565,14 +988,7 @@ def parse_test_xml( ) -> TestResults: # Route to Jest-specific parser for JavaScript/TypeScript tests if is_javascript(): - return _parse_jest_test_xml( - test_xml_file_path, - test_files, - test_config, - run_result, - parse_func=parse_func, - resolve_test_file_from_class_path=resolve_test_file_from_class_path, - ) + return parse_jest_test_xml(test_xml_file_path, test_files, test_config, run_result) test_results = TestResults() # Parse unittest output diff --git a/packages/codeflash/package.json b/packages/codeflash/package.json index dfd3abdf1..8135db4a3 100644 --- a/packages/codeflash/package.json +++ b/packages/codeflash/package.json @@ -6,7 +6,8 @@ "types": "runtime/index.d.ts", "bin": { "codeflash": "./bin/codeflash.js", - "codeflash-setup": "./bin/codeflash-setup.js" + "codeflash-setup": "./bin/codeflash-setup.js", + "codeflash-trace": "./runtime/trace-runner.js" }, "publishConfig": { "access": "public" @@ -32,6 +33,18 @@ "./loop-runner": { "require": "./runtime/loop-runner.js", "import": "./runtime/loop-runner.js" + }, + "./tracer": { + "require": "./runtime/tracer.js", + "import": "./runtime/tracer.js" + }, + "./replay": { + "require": "./runtime/replay.js", + "import": "./runtime/replay.js" + }, + "./babel-tracer-plugin": { + "require": "./runtime/babel-tracer-plugin.js", + "import": "./runtime/babel-tracer-plugin.js" } }, "scripts": { @@ -88,5 +101,11 @@ "dependencies": { "better-sqlite3": "^12.0.0", "@msgpack/msgpack": "^3.0.0" + }, + "optionalDependencies": { + "@babel/core": "^7.24.0", + "@babel/register": "^7.24.0", + "@babel/preset-env": "^7.24.0", + "@babel/preset-typescript": "^7.24.0" } } diff --git a/packages/codeflash/runtime/babel-tracer-plugin.js b/packages/codeflash/runtime/babel-tracer-plugin.js new file mode 100644 index 000000000..558e0664f --- /dev/null +++ b/packages/codeflash/runtime/babel-tracer-plugin.js @@ -0,0 +1,434 @@ +/** + * Codeflash Babel Tracer Plugin + * + * A Babel plugin that instruments JavaScript/TypeScript functions for tracing. 
+ * This plugin wraps functions with tracing calls to capture: + * - Function arguments + * - Return values + * - Execution time + * + * The plugin transforms: + * function foo(a, b) { return a + b; } + * + * Into: + * const __codeflash_tracer__ = require('codeflash/tracer'); + * function foo(a, b) { + * return __codeflash_tracer__.wrap(function foo(a, b) { return a + b; }, 'foo', '/path/file.js', 1) + * .apply(this, arguments); + * } + * + * Supported function types: + * - FunctionDeclaration: function foo() {} + * - FunctionExpression: const foo = function() {} + * - ArrowFunctionExpression: const foo = () => {} + * - ClassMethod: class Foo { bar() {} } + * - ObjectMethod: const obj = { foo() {} } + * + * Configuration (via plugin options or environment variables): + * - functions: Array of function names to trace (traces all if not set) + * - files: Array of file patterns to trace (traces all if not set) + * - exclude: Array of patterns to exclude from tracing + * + * Usage with @babel/register: + * require('@babel/register')({ + * plugins: [['codeflash/babel-tracer-plugin', { functions: ['myFunc'] }]], + * }); + * + * Environment Variables: + * CODEFLASH_FUNCTIONS - JSON array of functions to trace + * CODEFLASH_TRACE_FILES - JSON array of file patterns to trace + * CODEFLASH_TRACE_EXCLUDE - JSON array of patterns to exclude + */ + +'use strict'; + +const path = require('path'); + +// Parse environment variables for configuration +function getEnvConfig() { + const config = { + functions: null, + files: null, + exclude: null, + }; + + try { + if (process.env.CODEFLASH_FUNCTIONS) { + config.functions = JSON.parse(process.env.CODEFLASH_FUNCTIONS); + } + } catch (e) { + console.error('[codeflash-babel] Failed to parse CODEFLASH_FUNCTIONS:', e.message); + } + + try { + if (process.env.CODEFLASH_TRACE_FILES) { + config.files = JSON.parse(process.env.CODEFLASH_TRACE_FILES); + } + } catch (e) { + console.error('[codeflash-babel] Failed to parse CODEFLASH_TRACE_FILES:', e.message); + } + + try { + if (process.env.CODEFLASH_TRACE_EXCLUDE) { + config.exclude = JSON.parse(process.env.CODEFLASH_TRACE_EXCLUDE); + } + } catch (e) { + console.error('[codeflash-babel] Failed to parse CODEFLASH_TRACE_EXCLUDE:', e.message); + } + + return config; +} + +/** + * Check if a function should be traced based on configuration. 
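+ *
+ * Illustrative example, assuming config = { functions: ['foo', 'Bar.baz'] }:
+ *   shouldTraceFunction('foo', '/src/a.js', null, config)   // true
+ *   shouldTraceFunction('baz', '/src/b.js', 'Bar', config)  // true (matches 'Bar.baz')
+ *   shouldTraceFunction('qux', '/src/c.js', null, config)   // false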
+ * + * @param {string} funcName - Function name + * @param {string} fileName - File path + * @param {string|null} className - Class name (for methods) + * @param {Object} config - Plugin configuration + * @returns {boolean} - True if function should be traced + */ +function shouldTraceFunction(funcName, fileName, className, config) { + // Check exclude patterns first + if (config.exclude && config.exclude.length > 0) { + for (const pattern of config.exclude) { + if (typeof pattern === 'string') { + if (funcName === pattern || fileName.includes(pattern)) { + return false; + } + } else if (pattern instanceof RegExp) { + if (pattern.test(funcName) || pattern.test(fileName)) { + return false; + } + } + } + } + + // Check file patterns + if (config.files && config.files.length > 0) { + const matchesFile = config.files.some(pattern => { + if (typeof pattern === 'string') { + return fileName.includes(pattern); + } + if (pattern instanceof RegExp) { + return pattern.test(fileName); + } + return false; + }); + if (!matchesFile) return false; + } + + // Check function names + if (config.functions && config.functions.length > 0) { + const matchesName = config.functions.some(f => { + if (typeof f === 'string') { + return f === funcName || f === `${className}.${funcName}`; + } + // Support object format: { function: 'name', file: 'path', class: 'className' } + if (typeof f === 'object' && f !== null) { + if (f.function && f.function !== funcName) return false; + if (f.file && !fileName.includes(f.file)) return false; + if (f.class && f.class !== className) return false; + return true; + } + return false; + }); + if (!matchesName) return false; + } + + return true; +} + +/** + * Check if a path should be excluded from tracing (node_modules, etc.) + * + * @param {string} fileName - File path + * @returns {boolean} - True if file should be excluded + */ +function isExcludedPath(fileName) { + // Always exclude node_modules + if (fileName.includes('node_modules')) return true; + + // Exclude common test runner internals + if (fileName.includes('jest-runner') || fileName.includes('jest-jasmine')) return true; + if (fileName.includes('@vitest')) return true; + + // Exclude this plugin itself + if (fileName.includes('codeflash/runtime')) return true; + if (fileName.includes('babel-tracer-plugin')) return true; + + return false; +} + +/** + * Create the Babel plugin. 
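+ *
+ * A minimal registration sketch via a Babel config file (options illustrative):
+ *   // babel.config.js
+ *   module.exports = {
+ *     plugins: [['codeflash/babel-tracer-plugin', { functions: ['myFunc'] }]],
+ *   };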
+ * + * @param {Object} babel - Babel object with types (t) + * @returns {Object} - Babel plugin configuration + */ +module.exports = function codeflashTracerPlugin(babel) { + const { types: t } = babel; + + // Merge environment config with plugin options + const envConfig = getEnvConfig(); + + return { + name: 'codeflash-tracer', + + visitor: { + Program: { + enter(programPath, state) { + // Merge options from plugin config and environment + state.codeflashConfig = { + ...envConfig, + ...(state.opts || {}), + }; + + // Track whether we've added the tracer import + state.tracerImportAdded = false; + + // Get file info + state.fileName = state.filename || state.file.opts.filename || 'unknown'; + + // Check if entire file should be excluded + if (isExcludedPath(state.fileName)) { + state.skipFile = true; + return; + } + + state.skipFile = false; + }, + + exit(programPath, state) { + // Add tracer import if we instrumented any functions + if (state.tracerImportAdded) { + const tracerRequire = t.variableDeclaration('const', [ + t.variableDeclarator( + t.identifier('__codeflash_tracer__'), + t.callExpression( + t.identifier('require'), + [t.stringLiteral('codeflash/tracer')] + ) + ), + ]); + + // Add at the beginning of the program + programPath.unshiftContainer('body', tracerRequire); + } + }, + }, + + // Handle: function foo() {} + FunctionDeclaration(path, state) { + if (state.skipFile) return; + if (!path.node.id) return; // Skip anonymous functions + + const funcName = path.node.id.name; + const lineNumber = path.node.loc ? path.node.loc.start.line : 0; + + if (!shouldTraceFunction(funcName, state.fileName, null, state.codeflashConfig)) { + return; + } + + // Transform the function body to wrap with tracing + wrapFunctionBody(t, path, funcName, state.fileName, lineNumber, null); + state.tracerImportAdded = true; + }, + + // Handle: const foo = function() {} or const foo = () => {} + VariableDeclarator(path, state) { + if (state.skipFile) return; + if (!t.isIdentifier(path.node.id)) return; + if (!path.node.init) return; + + const init = path.node.init; + if (!t.isFunctionExpression(init) && !t.isArrowFunctionExpression(init)) { + return; + } + + const funcName = path.node.id.name; + const lineNumber = path.node.loc ? path.node.loc.start.line : 0; + + if (!shouldTraceFunction(funcName, state.fileName, null, state.codeflashConfig)) { + return; + } + + // Wrap the function expression with tracer.wrap() + path.node.init = createWrapperCall(t, init, funcName, state.fileName, lineNumber, null); + state.tracerImportAdded = true; + }, + + // Handle: class Foo { bar() {} } + ClassMethod(path, state) { + if (state.skipFile) return; + if (path.node.kind === 'constructor') return; // Skip constructors for now + + const funcName = path.node.key.name || (path.node.key.value && String(path.node.key.value)); + if (!funcName) return; + + // Get class name from parent + const classPath = path.findParent(p => t.isClassDeclaration(p) || t.isClassExpression(p)); + const className = classPath && classPath.node.id ? classPath.node.id.name : null; + + const lineNumber = path.node.loc ? 
path.node.loc.start.line : 0; + + if (!shouldTraceFunction(funcName, state.fileName, className, state.codeflashConfig)) { + return; + } + + // Wrap the method body + wrapMethodBody(t, path, funcName, state.fileName, lineNumber, className); + state.tracerImportAdded = true; + }, + + // Handle: const obj = { foo() {} } + ObjectMethod(path, state) { + if (state.skipFile) return; + + const funcName = path.node.key.name || (path.node.key.value && String(path.node.key.value)); + if (!funcName) return; + + const lineNumber = path.node.loc ? path.node.loc.start.line : 0; + + if (!shouldTraceFunction(funcName, state.fileName, null, state.codeflashConfig)) { + return; + } + + // Wrap the method body + wrapMethodBody(t, path, funcName, state.fileName, lineNumber, null); + state.tracerImportAdded = true; + }, + }, + }; +}; + +/** + * Create a __codeflash_tracer__.wrap() call expression. + * + * @param {Object} t - Babel types + * @param {Object} funcNode - The function AST node + * @param {string} funcName - Function name + * @param {string} fileName - File path + * @param {number} lineNumber - Line number + * @param {string|null} className - Class name + * @returns {Object} - Call expression AST node + */ +function createWrapperCall(t, funcNode, funcName, fileName, lineNumber, className) { + const args = [ + funcNode, + t.stringLiteral(funcName), + t.stringLiteral(fileName), + t.numericLiteral(lineNumber), + ]; + + if (className) { + args.push(t.stringLiteral(className)); + } else { + args.push(t.nullLiteral()); + } + + return t.callExpression( + t.memberExpression( + t.identifier('__codeflash_tracer__'), + t.identifier('wrap') + ), + args + ); +} + +/** + * Wrap a function declaration's body with tracing. + * Transforms: + * function foo(a, b) { return a + b; } + * Into: + * function foo(a, b) { + * const __original__ = function(a, b) { return a + b; }; + * return __codeflash_tracer__.wrap(__original__, 'foo', 'file.js', 1, null).apply(this, arguments); + * } + * + * @param {Object} t - Babel types + * @param {Object} path - Babel path + * @param {string} funcName - Function name + * @param {string} fileName - File path + * @param {number} lineNumber - Line number + * @param {string|null} className - Class name + */ +function wrapFunctionBody(t, path, funcName, fileName, lineNumber, className) { + const node = path.node; + const isAsync = node.async; + const isGenerator = node.generator; + + // Create a copy of the original function as an expression + const originalFunc = t.functionExpression( + null, // anonymous + node.params, + node.body, + isGenerator, + isAsync + ); + + // Create the wrapper call + const wrapperCall = createWrapperCall(t, originalFunc, funcName, fileName, lineNumber, className); + + // Create: return __codeflash_tracer__.wrap(...).apply(this, arguments) + const applyCall = t.callExpression( + t.memberExpression(wrapperCall, t.identifier('apply')), + [t.thisExpression(), t.identifier('arguments')] + ); + + const returnStatement = t.returnStatement(applyCall); + + // Replace the function body + node.body = t.blockStatement([returnStatement]); +} + +/** + * Wrap a method's body with tracing. + * Similar to wrapFunctionBody but preserves method semantics. 
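+ * For async methods the tracer call is awaited below, so the method still
+ * resolves to the original return value; sync methods return the wrapped
+ * result directly.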
+ * + * @param {Object} t - Babel types + * @param {Object} path - Babel path + * @param {string} funcName - Function name + * @param {string} fileName - File path + * @param {number} lineNumber - Line number + * @param {string|null} className - Class name + */ +function wrapMethodBody(t, path, funcName, fileName, lineNumber, className) { + const node = path.node; + const isAsync = node.async; + const isGenerator = node.generator; + + // Create a copy of the original function as an expression + const originalFunc = t.functionExpression( + null, // anonymous + node.params, + node.body, + isGenerator, + isAsync + ); + + // Create the wrapper call + const wrapperCall = createWrapperCall(t, originalFunc, funcName, fileName, lineNumber, className); + + // Create: return __codeflash_tracer__.wrap(...).apply(this, arguments) + const applyCall = t.callExpression( + t.memberExpression(wrapperCall, t.identifier('apply')), + [t.thisExpression(), t.identifier('arguments')] + ); + + let returnStatement; + if (isAsync) { + // For async methods, we need to await the result + returnStatement = t.returnStatement(t.awaitExpression(applyCall)); + } else { + returnStatement = t.returnStatement(applyCall); + } + + // Replace the function body + node.body = t.blockStatement([returnStatement]); +} + +// Export helper functions for testing +module.exports.shouldTraceFunction = shouldTraceFunction; +module.exports.isExcludedPath = isExcludedPath; +module.exports.getEnvConfig = getEnvConfig; diff --git a/packages/codeflash/runtime/index.js b/packages/codeflash/runtime/index.js index 982912c24..864b03066 100644 --- a/packages/codeflash/runtime/index.js +++ b/packages/codeflash/runtime/index.js @@ -8,6 +8,8 @@ * - capturePerf: Capture performance metrics (timing only) * - serialize/deserialize: Value serialization for storage * - comparator: Deep equality comparison + * - tracer: Function tracing for replay test generation + * - replay: Replay test utilities * * Usage (CommonJS): * const { capture, capturePerf } = require('codeflash'); @@ -30,6 +32,22 @@ const comparator = require('./comparator'); // Result comparison (used by CLI) const compareResults = require('./compare-results'); +// Function tracing (for replay test generation) +let tracer = null; +try { + tracer = require('./tracer'); +} catch (e) { + // Tracer may not be available if better-sqlite3 is not installed +} + +// Replay test utilities +let replay = null; +try { + replay = require('./replay'); +} catch (e) { + // Replay may not be available +} + // Re-export all public APIs module.exports = { // === Main Instrumentation API === @@ -88,4 +106,24 @@ module.exports = { // === Feature Detection === hasV8: serializer.hasV8, hasMsgpack: serializer.hasMsgpack, + + // === Function Tracing (for replay test generation) === + tracer: tracer ? { + init: tracer.init, + wrap: tracer.wrap, + createWrapper: tracer.createWrapper, + disable: tracer.disable, + enable: tracer.enable, + getStats: tracer.getStats, + } : null, + + // === Replay Test Utilities === + replay: replay ? 
{
+    getNextArg: replay.getNextArg,
+    getTracesWithMetadata: replay.getTracesWithMetadata,
+    getTracedFunctions: replay.getTracedFunctions,
+    getTraceMetadata: replay.getTraceMetadata,
+    generateReplayTest: replay.generateReplayTest,
+    createReplayTestFromTrace: replay.createReplayTestFromTrace,
+  } : null,
 };
diff --git a/packages/codeflash/runtime/replay.js b/packages/codeflash/runtime/replay.js
new file mode 100644
index 000000000..733fed41e
--- /dev/null
+++ b/packages/codeflash/runtime/replay.js
@@ -0,0 +1,454 @@
+/**
+ * Codeflash Replay Test Utilities
+ *
+ * This module provides utilities for generating and running replay tests
+ * from traced function calls. Replay tests allow verifying that optimized
+ * code produces the same results as the original code.
+ *
+ * Usage:
+ *   const { getNextArg, createReplayTestFromTrace } = require('codeflash/replay');
+ *
+ *   // In a test file:
+ *   describe('Replay tests', () => {
+ *     test.each(getNextArg(traceFile, 'myFunction', '/path/file.js', 25))
+ *       ('myFunction replay %#', (...args) => {
+ *         myFunction(...args);
+ *       });
+ *   });
+ *
+ * The module supports both Jest and Vitest test frameworks.
+ */
+
+'use strict';
+
+const path = require('path');
+const fs = require('fs');
+
+// Load the codeflash serializer for argument deserialization
+const serializer = require('./serializer');
+
+// ============================================================================
+// DATABASE ACCESS
+// ============================================================================
+
+/**
+ * Open a SQLite database connection.
+ *
+ * @param {string} dbPath - Path to the SQLite database
+ * @returns {Object|null} - Database connection or null if failed
+ */
+function openDatabase(dbPath) {
+  try {
+    const Database = require('better-sqlite3');
+    return new Database(dbPath, { readonly: true });
+  } catch (e) {
+    console.error('[codeflash-replay] Failed to open database:', e.message);
+    return null;
+  }
+}
+
+/**
+ * Get traced function calls from the database.
+ *
+ * @param {string} traceFile - Path to the trace SQLite database
+ * @param {string} functionName - Name of the function
+ * @param {string} fileName - Path to the source file
+ * @param {number} limit - Maximum number of traces to retrieve
+ * @param {string|null} className - Class name (for methods)
+ * @returns {Array} - Array of traced argument lists
+ */
+function getNextArg(traceFile, functionName, fileName, limit = 25, className = null) {
+  const db = openDatabase(traceFile);
+  if (!db) {
+    return [];
+  }
+
+  try {
+    let stmt;
+    let rows;
+
+    if (className) {
+      stmt = db.prepare(`
+        SELECT args FROM function_calls
+        WHERE function = ? AND filename = ? AND classname = ? AND type = 'call'
+        ORDER BY time_ns ASC
+        LIMIT ?
+      `);
+      rows = stmt.all(functionName, fileName, className, limit);
+    } else {
+      stmt = db.prepare(`
+        SELECT args FROM function_calls
+        WHERE function = ? AND filename = ? AND type = 'call'
+        ORDER BY time_ns ASC
+        LIMIT ?
+      `);
+      rows = stmt.all(functionName, fileName, limit);
+    }
+
+    db.close();
+
+    // Deserialize arguments
+    return rows.map((row, index) => {
+      try {
+        const args = serializer.deserialize(row.args);
+        return args;
+      } catch (e) {
+        console.warn(`[codeflash-replay] Failed to deserialize args at index ${index}:`, e.message);
+        return [];
+      }
+    });
+  } catch (e) {
+    console.error('[codeflash-replay] Database query failed:', e.message);
+    db.close();
+    return [];
+  }
+}
+
+/**
+ * Get traced function calls with full metadata. 
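+ *
+ * Each returned entry has the shape (values illustrative):
+ *   { args: [1, 2], function: 'add', className: null,
+ *     fileName: '/src/math.js', lineNumber: 3, timeNs: 1234 }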
+ *
+ * @param {string} traceFile - Path to the trace SQLite database
+ * @param {string} functionName - Name of the function
+ * @param {string} fileName - Path to the source file
+ * @param {number} limit - Maximum number of traces to retrieve
+ * @param {string|null} className - Class name (for methods)
+ * @returns {Array} - Array of trace objects with args and metadata
+ */
+function getTracesWithMetadata(traceFile, functionName, fileName, limit = 25, className = null) {
+  const db = openDatabase(traceFile);
+  if (!db) {
+    return [];
+  }
+
+  try {
+    let stmt;
+    let rows;
+
+    if (className) {
+      stmt = db.prepare(`
+        SELECT type, function, classname, filename, line_number, time_ns, args
+        FROM function_calls
+        WHERE function = ? AND filename = ? AND classname = ? AND type = 'call'
+        ORDER BY time_ns ASC
+        LIMIT ?
+      `);
+      rows = stmt.all(functionName, fileName, className, limit);
+    } else {
+      stmt = db.prepare(`
+        SELECT type, function, classname, filename, line_number, time_ns, args
+        FROM function_calls
+        WHERE function = ? AND filename = ? AND type = 'call'
+        ORDER BY time_ns ASC
+        LIMIT ?
+      `);
+      rows = stmt.all(functionName, fileName, limit);
+    }
+
+    db.close();
+
+    // Deserialize arguments and return with metadata
+    return rows.map((row, index) => {
+      let args;
+      try {
+        args = serializer.deserialize(row.args);
+      } catch (e) {
+        console.warn(`[codeflash-replay] Failed to deserialize args at index ${index}:`, e.message);
+        args = [];
+      }
+
+      return {
+        args,
+        function: row.function,
+        className: row.classname,
+        fileName: row.filename,
+        lineNumber: row.line_number,
+        timeNs: row.time_ns,
+      };
+    });
+  } catch (e) {
+    console.error('[codeflash-replay] Database query failed:', e.message);
+    db.close();
+    return [];
+  }
+}
+
+/**
+ * Get all traced functions from the database.
+ *
+ * @param {string} traceFile - Path to the trace SQLite database
+ * @returns {Array} - Array of { function, fileName, className, count } objects
+ */
+function getTracedFunctions(traceFile) {
+  const db = openDatabase(traceFile);
+  if (!db) {
+    return [];
+  }
+
+  try {
+    const stmt = db.prepare(`
+      SELECT function, filename, classname, COUNT(*) as count
+      FROM function_calls
+      WHERE type = 'call'
+      GROUP BY function, filename, classname
+      ORDER BY count DESC
+    `);
+    const rows = stmt.all();
+    db.close();
+
+    return rows.map(row => ({
+      function: row.function,
+      fileName: row.filename,
+      className: row.classname,
+      count: row.count,
+    }));
+  } catch (e) {
+    console.error('[codeflash-replay] Failed to get traced functions:', e.message);
+    db.close();
+    return [];
+  }
+}
+
+/**
+ * Get metadata from the trace database.
+ *
+ * @param {string} traceFile - Path to the trace SQLite database
+ * @returns {Object} - Metadata key-value pairs
+ */
+function getTraceMetadata(traceFile) {
+  const db = openDatabase(traceFile);
+  if (!db) {
+    return {};
+  }
+
+  try {
+    const stmt = db.prepare('SELECT key, value FROM metadata');
+    const rows = stmt.all();
+    db.close();
+
+    const metadata = {};
+    for (const row of rows) {
+      metadata[row.key] = row.value;
+    }
+    return metadata;
+  } catch (e) {
+    console.error('[codeflash-replay] Failed to get metadata:', e.message);
+    db.close();
+    return {};
+  }
+}
+
+// ============================================================================
+// TEST GENERATION
+// ============================================================================
+
+/**
+ * Generate a Jest/Vitest replay test file. 
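+ *
+ * A minimal usage sketch (paths and names are illustrative):
+ *   generateReplayTest('codeflash.trace.sqlite',
+ *     [{ function: 'add', fileName: '/src/math.js', className: null, modulePath: './src/math' }],
+ *     { framework: 'jest', maxRunCount: 100 });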
+ * + * @param {string} traceFile - Path to the trace SQLite database + * @param {Array} functions - Array of { function, fileName, className, modulePath } to test + * @param {Object} options - Generation options + * @returns {string} - Generated test file content + */ +function generateReplayTest(traceFile, functions, options = {}) { + const { + framework = 'jest', // 'jest' or 'vitest' + maxRunCount = 100, + outputPath = null, + } = options; + + const isVitest = framework === 'vitest'; + + // Build imports section + const imports = []; + + if (isVitest) { + imports.push("import { describe, test } from 'vitest';"); + } + + imports.push("const { getNextArg } = require('codeflash/replay');"); + imports.push(''); + + // Build function imports + for (const func of functions) { + const alias = getFunctionAlias(func.modulePath, func.function, func.className); + + if (func.className) { + // Import class for method testing + imports.push(`const { ${func.className}: ${alias}_class } = require('${func.modulePath}');`); + } else { + // Import function directly + imports.push(`const { ${func.function}: ${alias} } = require('${func.modulePath}');`); + } + } + + imports.push(''); + + // Metadata + const metadata = [ + `const traceFilePath = '${traceFile}';`, + `const functions = ${JSON.stringify(functions.map(f => f.function))};`, + '', + ]; + + // Build test cases + const testCases = []; + + for (const func of functions) { + const alias = getFunctionAlias(func.modulePath, func.function, func.className); + const testName = func.className + ? `${func.className}.${func.function}` + : func.function; + + if (func.className) { + // Method test + testCases.push(` +describe('Replay: ${testName}', () => { + const traces = getNextArg(traceFilePath, '${func.function}', '${func.fileName}', ${maxRunCount}, '${func.className}'); + + test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => { + // For instance methods, first arg is 'this' context + const [thisArg, ...methodArgs] = args; + const instance = thisArg || new ${alias}_class(); + instance.${func.function}(...methodArgs); + }); +}); +`); + } else { + // Function test + testCases.push(` +describe('Replay: ${testName}', () => { + const traces = getNextArg(traceFilePath, '${func.function}', '${func.fileName}', ${maxRunCount}); + + test.each(traces.map((args, i) => [i, args]))('call %i', (index, args) => { + ${alias}(...args); + }); +}); +`); + } + } + + // Combine all parts + const content = [ + '// Auto-generated replay test by Codeflash', + '// Do not edit this file directly', + '', + ...imports, + ...metadata, + ...testCases, + ].join('\n'); + + // Write to file if outputPath provided + if (outputPath) { + const dir = path.dirname(outputPath); + if (!fs.existsSync(dir)) { + fs.mkdirSync(dir, { recursive: true }); + } + fs.writeFileSync(outputPath, content); + console.log(`[codeflash-replay] Generated test file: ${outputPath}`); + } + + return content; +} + +/** + * Create a function alias for imports to avoid naming conflicts. 
+ * + * @param {string} modulePath - Module path + * @param {string} functionName - Function name + * @param {string|null} className - Class name + * @returns {string} - Alias name + */ +function getFunctionAlias(modulePath, functionName, className = null) { + // Normalize module path to valid identifier + const moduleAlias = modulePath + .replace(/[^a-zA-Z0-9]/g, '_') + .replace(/^_+|_+$/g, ''); + + if (className) { + return `${moduleAlias}_${className}_${functionName}`; + } + return `${moduleAlias}_${functionName}`; +} + +/** + * Create replay tests from a trace file. + * This is the main entry point for Python integration. + * + * @param {string} traceFile - Path to the trace SQLite database + * @param {string} outputPath - Path to write the test file + * @param {Object} options - Generation options + * @returns {Object} - { success, outputPath, functions } + */ +function createReplayTestFromTrace(traceFile, outputPath, options = {}) { + const { + framework = 'jest', + maxRunCount = 100, + projectRoot = process.cwd(), + } = options; + + // Get all traced functions + const tracedFunctions = getTracedFunctions(traceFile); + + if (tracedFunctions.length === 0) { + console.warn('[codeflash-replay] No traced functions found in database'); + return { success: false, outputPath: null, functions: [] }; + } + + // Convert to the format expected by generateReplayTest + const functions = tracedFunctions.map(tf => { + // Calculate module path from file name + let modulePath = tf.fileName; + + // Make relative to project root + if (path.isAbsolute(modulePath)) { + modulePath = path.relative(projectRoot, modulePath); + } + + // Convert to module path (remove .js extension, use forward slashes) + modulePath = './' + modulePath + .replace(/\\/g, '/') + .replace(/\.js$/, '') + .replace(/\.ts$/, ''); + + return { + function: tf.function, + fileName: tf.fileName, + className: tf.className, + modulePath, + }; + }); + + // Generate the test file + const testContent = generateReplayTest(traceFile, functions, { + framework, + maxRunCount, + outputPath, + }); + + return { + success: true, + outputPath, + functions: functions.map(f => f.function), + content: testContent, + }; +} + +// ============================================================================ +// EXPORTS +// ============================================================================ + +module.exports = { + // Core API + getNextArg, + getTracesWithMetadata, + getTracedFunctions, + getTraceMetadata, + + // Test generation + generateReplayTest, + createReplayTestFromTrace, + getFunctionAlias, + + // Database utilities + openDatabase, +}; diff --git a/packages/codeflash/runtime/trace-runner.js b/packages/codeflash/runtime/trace-runner.js new file mode 100644 index 000000000..f8e34148f --- /dev/null +++ b/packages/codeflash/runtime/trace-runner.js @@ -0,0 +1,381 @@ +#!/usr/bin/env node +/** + * Codeflash Trace Runner + * + * Entry point script that runs JavaScript/TypeScript code with function tracing enabled. + * This script: + * 1. Registers Babel with the tracer plugin for AST transformation + * 2. Sets up environment variables for tracing configuration + * 3. 
Runs the user's script, tests, or module + * + * Usage: + * # Run a script with tracing + * node trace-runner.js script.js + * + * # Run tests with tracing (Jest) + * node trace-runner.js --jest -- --testPathPattern=mytest + * + * # Run tests with tracing (Vitest) + * node trace-runner.js --vitest -- --run + * + * # Run with specific functions to trace + * node trace-runner.js --functions='["myFunc","otherFunc"]' script.js + * + * Environment Variables (also settable via command line): + * CODEFLASH_TRACE_DB - Path to SQLite database for storing traces + * CODEFLASH_PROJECT_ROOT - Project root for relative path calculation + * CODEFLASH_FUNCTIONS - JSON array of functions to trace + * CODEFLASH_MAX_FUNCTION_COUNT - Maximum traces per function (default: 256) + * CODEFLASH_TRACER_TIMEOUT - Timeout in seconds for tracing + * + * For ESM (ECMAScript modules), use the loader flag: + * node --loader ./esm-loader.mjs trace-runner.js script.mjs + */ + +'use strict'; + +const path = require('path'); +const fs = require('fs'); + +// ============================================================================ +// ARGUMENT PARSING +// ============================================================================ + +function parseArgs(args) { + const config = { + traceDb: process.env.CODEFLASH_TRACE_DB || path.join(process.cwd(), 'codeflash.trace.sqlite'), + projectRoot: process.env.CODEFLASH_PROJECT_ROOT || process.cwd(), + functions: process.env.CODEFLASH_FUNCTIONS || null, + maxFunctionCount: process.env.CODEFLASH_MAX_FUNCTION_COUNT || '256', + tracerTimeout: process.env.CODEFLASH_TRACER_TIMEOUT || null, + traceFiles: process.env.CODEFLASH_TRACE_FILES || null, + traceExclude: process.env.CODEFLASH_TRACE_EXCLUDE || null, + jest: false, + vitest: false, + module: false, + script: null, + scriptArgs: [], + }; + + let i = 0; + while (i < args.length) { + const arg = args[i]; + + if (arg === '--trace-db') { + config.traceDb = args[++i]; + } else if (arg.startsWith('--trace-db=')) { + config.traceDb = arg.split('=')[1]; + } else if (arg === '--project-root') { + config.projectRoot = args[++i]; + } else if (arg.startsWith('--project-root=')) { + config.projectRoot = arg.split('=')[1]; + } else if (arg === '--functions') { + config.functions = args[++i]; + } else if (arg.startsWith('--functions=')) { + config.functions = arg.split('=')[1]; + } else if (arg === '--max-function-count') { + config.maxFunctionCount = args[++i]; + } else if (arg.startsWith('--max-function-count=')) { + config.maxFunctionCount = arg.split('=')[1]; + } else if (arg === '--timeout') { + config.tracerTimeout = args[++i]; + } else if (arg.startsWith('--timeout=')) { + config.tracerTimeout = arg.split('=')[1]; + } else if (arg === '--trace-files') { + config.traceFiles = args[++i]; + } else if (arg.startsWith('--trace-files=')) { + config.traceFiles = arg.split('=')[1]; + } else if (arg === '--trace-exclude') { + config.traceExclude = args[++i]; + } else if (arg.startsWith('--trace-exclude=')) { + config.traceExclude = arg.split('=')[1]; + } else if (arg === '--jest') { + config.jest = true; + } else if (arg === '--vitest') { + config.vitest = true; + } else if (arg === '-m' || arg === '--module') { + config.module = true; + } else if (arg === '--') { + // Everything after -- is passed to the script/test runner + config.scriptArgs = args.slice(i + 1); + break; + } else if (arg === '--help' || arg === '-h') { + printHelp(); + process.exit(0); + } else if (!arg.startsWith('-')) { + // First non-flag argument is the script + config.script 
= arg; + config.scriptArgs = args.slice(i + 1); + break; + } + + i++; + } + + return config; +} + +function printHelp() { + console.log(` +Codeflash Trace Runner - JavaScript Function Tracing + +Usage: + trace-runner [options]