diff --git a/code_to_optimize/js/code_to_optimize_js/bubble_sort.js b/code_to_optimize/js/code_to_optimize_js/bubble_sort.js index 8f3c9ffca..8438a3cdb 100644 --- a/code_to_optimize/js/code_to_optimize_js/bubble_sort.js +++ b/code_to_optimize/js/code_to_optimize_js/bubble_sort.js @@ -7,7 +7,7 @@ * @param {number[]} arr - The array to sort * @returns {number[]} - The sorted array */ -function bubbleSort(arr) { +export function bubbleSort(arr) { const result = arr.slice(); const n = result.length; @@ -29,7 +29,7 @@ function bubbleSort(arr) { * @param {number[]} arr - The array to sort * @returns {number[]} - The sorted array in descending order */ -function bubbleSortDescending(arr) { +export function bubbleSortDescending(arr) { const n = arr.length; const result = [...arr]; diff --git a/code_to_optimize/js/code_to_optimize_js/calculator.js b/code_to_optimize/js/code_to_optimize_js/calculator.js index 3eceb7a70..cecf92ebb 100644 --- a/code_to_optimize/js/code_to_optimize_js/calculator.js +++ b/code_to_optimize/js/code_to_optimize_js/calculator.js @@ -11,7 +11,7 @@ const { sumArray, average, findMax, findMin } = require('./math_helpers'); * @param numbers - Array of numbers to analyze * @returns Object containing sum, average, min, max, and range */ -function calculateStats(numbers) { +export function calculateStats(numbers) { if (numbers.length === 0) { return { sum: 0, @@ -42,7 +42,7 @@ function calculateStats(numbers) { * @param numbers - Array of numbers to normalize * @returns Normalized array */ -function normalizeArray(numbers) { +export function normalizeArray(numbers) { if (numbers.length === 0) return []; const min = findMin(numbers); @@ -62,7 +62,7 @@ function normalizeArray(numbers) { * @param weights - Array of weights (same length as values) * @returns The weighted average */ -function weightedAverage(values, weights) { +export function weightedAverage(values, weights) { if (values.length === 0 || values.length !== weights.length) { return 0; } diff --git a/code_to_optimize/js/code_to_optimize_js/fibonacci.js b/code_to_optimize/js/code_to_optimize_js/fibonacci.js index b0ab2b51c..9ab921d90 100644 --- a/code_to_optimize/js/code_to_optimize_js/fibonacci.js +++ b/code_to_optimize/js/code_to_optimize_js/fibonacci.js @@ -8,7 +8,7 @@ * @param {number} n - The index of the Fibonacci number to calculate * @returns {number} - The nth Fibonacci number */ -function fibonacci(n) { +export function fibonacci(n) { if (n <= 1) { return n; } @@ -20,7 +20,7 @@ function fibonacci(n) { * @param {number} num - The number to check * @returns {boolean} - True if num is a Fibonacci number */ -function isFibonacci(num) { +export function isFibonacci(num) { // A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square const check1 = 5 * num * num + 4; const check2 = 5 * num * num - 4; @@ -33,7 +33,7 @@ function isFibonacci(num) { * @param {number} n - The number to check * @returns {boolean} - True if n is a perfect square */ -function isPerfectSquare(n) { +export function isPerfectSquare(n) { const sqrt = Math.sqrt(n); return sqrt === Math.floor(sqrt); } @@ -43,7 +43,7 @@ function isPerfectSquare(n) { * @param {number} n - The number of Fibonacci numbers to generate * @returns {number[]} - Array of Fibonacci numbers */ -function fibonacciSequence(n) { +export function fibonacciSequence(n) { const result = []; for (let i = 0; i < n; i++) { result.push(fibonacci(i)); diff --git a/code_to_optimize/js/code_to_optimize_js/math_helpers.js b/code_to_optimize/js/code_to_optimize_js/math_helpers.js index f6e7c9662..72a320919 100644 --- a/code_to_optimize/js/code_to_optimize_js/math_helpers.js +++ b/code_to_optimize/js/code_to_optimize_js/math_helpers.js @@ -8,7 +8,7 @@ * @param numbers - Array of numbers to sum * @returns The sum of all numbers */ -function sumArray(numbers) { +export function sumArray(numbers) { // Intentionally inefficient - using reduce with spread operator let result = 0; for (let i = 0; i < numbers.length; i++) { @@ -22,7 +22,7 @@ function sumArray(numbers) { * @param numbers - Array of numbers * @returns The average value */ -function average(numbers) { +export function average(numbers) { if (numbers.length === 0) return 0; return sumArray(numbers) / numbers.length; } @@ -32,7 +32,7 @@ function average(numbers) { * @param numbers - Array of numbers * @returns The maximum value */ -function findMax(numbers) { +export function findMax(numbers) { if (numbers.length === 0) return -Infinity; // Intentionally inefficient - sorting instead of linear scan @@ -45,7 +45,7 @@ function findMax(numbers) { * @param numbers - Array of numbers * @returns The minimum value */ -function findMin(numbers) { +export function findMin(numbers) { if (numbers.length === 0) return Infinity; // Intentionally inefficient - sorting instead of linear scan diff --git a/code_to_optimize/js/code_to_optimize_js/string_utils.js b/code_to_optimize/js/code_to_optimize_js/string_utils.js index 6881943e5..9c4eb5a04 100644 --- a/code_to_optimize/js/code_to_optimize_js/string_utils.js +++ b/code_to_optimize/js/code_to_optimize_js/string_utils.js @@ -7,7 +7,7 @@ * @param {string} str - The string to reverse * @returns {string} - The reversed string */ -function reverseString(str) { +export function reverseString(str) { // Intentionally inefficient O(n²) implementation for testing let result = ''; for (let i = str.length - 1; i >= 0; i--) { @@ -27,7 +27,7 @@ function reverseString(str) { * @param {string} str - The string to check * @returns {boolean} - True if str is a palindrome */ -function isPalindrome(str) { +export function isPalindrome(str) { const cleaned = str.toLowerCase().replace(/[^a-z0-9]/g, ''); return cleaned === reverseString(cleaned); } @@ -38,7 +38,7 @@ function isPalindrome(str) { * @param {string} sub - The substring to count * @returns {number} - Number of occurrences */ -function countOccurrences(str, sub) { +export function countOccurrences(str, sub) { let count = 0; let pos = 0; @@ -57,7 +57,7 @@ function countOccurrences(str, sub) { * @param {string[]} strs - Array of strings * @returns {string} - The longest common prefix */ -function longestCommonPrefix(strs) { +export function longestCommonPrefix(strs) { if (strs.length === 0) return ''; if (strs.length === 1) return strs[0]; @@ -78,7 +78,7 @@ function longestCommonPrefix(strs) { * @param {string} str - The string to convert * @returns {string} - The title-cased string */ -function toTitleCase(str) { +export function toTitleCase(str) { return str .toLowerCase() .split(' ') diff --git a/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci.js b/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci.js index 17de243bc..cdb9bd5f8 100644 --- a/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci.js +++ b/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci.js @@ -9,7 +9,7 @@ * @param {number} n - The index of the Fibonacci number to calculate * @returns {number} The nth Fibonacci number */ -function fibonacci(n) { +export function fibonacci(n) { if (n <= 1) { return n; } @@ -21,7 +21,7 @@ function fibonacci(n) { * @param {number} num - The number to check * @returns {boolean} True if num is a Fibonacci number */ -function isFibonacci(num) { +export function isFibonacci(num) { // A number is Fibonacci if one of (5*n*n + 4) or (5*n*n - 4) is a perfect square const check1 = 5 * num * num + 4; const check2 = 5 * num * num - 4; @@ -33,7 +33,7 @@ function isFibonacci(num) { * @param {number} n - The number to check * @returns {boolean} True if n is a perfect square */ -function isPerfectSquare(n) { +export function isPerfectSquare(n) { const sqrt = Math.sqrt(n); return sqrt === Math.floor(sqrt); } @@ -43,7 +43,7 @@ function isPerfectSquare(n) { * @param {number} n - The number of Fibonacci numbers to generate * @returns {number[]} Array of Fibonacci numbers */ -function fibonacciSequence(n) { +export function fibonacciSequence(n) { const result = []; for (let i = 0; i < n; i++) { result.push(fibonacci(i)); diff --git a/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js b/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js index 24621ee7f..9c816ada0 100644 --- a/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js +++ b/code_to_optimize/js/code_to_optimize_js_cjs/fibonacci_class.js @@ -3,7 +3,7 @@ * Intentionally inefficient for optimization testing. */ -class FibonacciCalculator { +export class FibonacciCalculator { constructor() { // No initialization needed } diff --git a/code_to_optimize/js/code_to_optimize_vitest/package-lock.json b/code_to_optimize/js/code_to_optimize_vitest/package-lock.json index ac3d39afd..ef24dc459 100644 --- a/code_to_optimize/js/code_to_optimize_vitest/package-lock.json +++ b/code_to_optimize/js/code_to_optimize_vitest/package-lock.json @@ -15,7 +15,7 @@ } }, "../../../packages/codeflash": { - "version": "0.7.0", + "version": "0.8.0", "dev": true, "hasInstallScript": true, "license": "MIT", diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index e44c279d3..ff04b5037 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -1,10 +1,5 @@ from __future__ import annotations -import datetime as dt -import re - -import humanize - def humanize_runtime(time_in_ns: int) -> str: runtime_human: str = str(time_in_ns) @@ -14,22 +9,32 @@ def humanize_runtime(time_in_ns: int) -> str: if time_in_ns / 1000 >= 1: time_micro = float(time_in_ns) / 1000 - runtime_human = humanize.precisedelta(dt.timedelta(microseconds=time_micro), minimum_unit="microseconds") - units = re.split(r",|\s", runtime_human)[1] - - if units in {"microseconds", "microsecond"}: + # Direct unit determination and formatting without external library + if time_micro < 1000: runtime_human = f"{time_micro:.3g}" - elif units in {"milliseconds", "millisecond"}: - runtime_human = "%.3g" % (time_micro / 1000) - elif units in {"seconds", "second"}: - runtime_human = "%.3g" % (time_micro / (1000**2)) - elif units in {"minutes", "minute"}: - runtime_human = "%.3g" % (time_micro / (60 * 1000**2)) - elif units in {"hour", "hours"}: # hours - runtime_human = "%.3g" % (time_micro / (3600 * 1000**2)) + units = "microseconds" if time_micro >= 2 else "microsecond" + elif time_micro < 1000000: + time_milli = time_micro / 1000 + runtime_human = f"{time_milli:.3g}" + units = "milliseconds" if time_milli >= 2 else "millisecond" + elif time_micro < 60000000: + time_sec = time_micro / 1000000 + runtime_human = f"{time_sec:.3g}" + units = "seconds" if time_sec >= 2 else "second" + elif time_micro < 3600000000: + time_min = time_micro / 60000000 + runtime_human = f"{time_min:.3g}" + units = "minutes" if time_min >= 2 else "minute" + elif time_micro < 86400000000: + time_hour = time_micro / 3600000000 + runtime_human = f"{time_hour:.3g}" + units = "hours" if time_hour >= 2 else "hour" else: # days - runtime_human = "%.3g" % (time_micro / (24 * 3600 * 1000**2)) + time_day = time_micro / 86400000000 + runtime_human = f"{time_day:.3g}" + units = "days" if time_day >= 2 else "day" + runtime_human_parts = str(runtime_human).split(".") if len(runtime_human_parts[0]) == 1: if runtime_human_parts[0] == "1" and len(runtime_human_parts) > 1: diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index 7416329bb..1a78e79e4 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -26,10 +26,10 @@ class PrComment: def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]: report_table: dict[str, dict[str, int]] = {} - for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items(): + for test_type, test_result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items(): name = test_type.to_name() if name: - report_table[name] = result + report_table[name] = test_result result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = { "optimization_explanation": self.optimization_explanation, @@ -45,8 +45,8 @@ def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[B } if self.original_async_throughput is not None and self.best_async_throughput is not None: - result["original_async_throughput"] = str(self.original_async_throughput) - result["best_async_throughput"] = str(self.best_async_throughput) + result["original_async_throughput"] = self.original_async_throughput + result["best_async_throughput"] = self.best_async_throughput return result diff --git a/codeflash/languages/javascript/instrument.py b/codeflash/languages/javascript/instrument.py index 30e7fff7a..a180c593f 100644 --- a/codeflash/languages/javascript/instrument.py +++ b/codeflash/languages/javascript/instrument.py @@ -962,3 +962,173 @@ def instrument_generated_js_test( mode=mode, remove_assertions=True, ) + + +def fix_imports_inside_test_blocks(test_code: str) -> str: + """Fix import statements that appear inside test/it blocks. + + JavaScript/TypeScript `import` statements must be at the top level of a module. + The AI sometimes generates imports inside test functions, which is invalid syntax. + + This function detects such patterns and converts them to dynamic require() calls + which are valid inside functions. + + Args: + test_code: The generated test code. + + Returns: + Fixed test code with imports converted to require() inside functions. + + """ + if not test_code or not test_code.strip(): + return test_code + + # Pattern to match import statements inside functions + # This captures imports that appear after function/test block openings + # We look for lines that: + # 1. Start with whitespace (indicating they're inside a block) + # 2. Have an import statement + + lines = test_code.split("\n") + result_lines = [] + brace_depth = 0 + in_test_block = False + + for line in lines: + stripped = line.strip() + + # Track brace depth to know if we're inside a block + # Count braces, but ignore braces in strings (simplified check) + for char in stripped: + if char == "{": + brace_depth += 1 + elif char == "}": + brace_depth -= 1 + + # Check if we're entering a test/it/describe block + if re.match(r"^(test|it|describe|beforeEach|afterEach|beforeAll|afterAll)\s*\(", stripped): + in_test_block = True + + # Check for import statement inside a block (brace_depth > 0 means we're inside a function/block) + if brace_depth > 0 and stripped.startswith("import "): + # Convert ESM import to require + # Pattern: import { name } from 'module' -> const { name } = require('module') + # Pattern: import name from 'module' -> const name = require('module') + + named_import = re.match(r"import\s+\{([^}]+)\}\s+from\s+['\"]([^'\"]+)['\"]", stripped) + default_import = re.match(r"import\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]", stripped) + namespace_import = re.match(r"import\s+\*\s+as\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]", stripped) + + leading_whitespace = line[: len(line) - len(line.lstrip())] + + if named_import: + names = named_import.group(1) + module = named_import.group(2) + new_line = f"{leading_whitespace}const {{{names}}} = require('{module}');" + result_lines.append(new_line) + logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}") + continue + if default_import: + name = default_import.group(1) + module = default_import.group(2) + new_line = f"{leading_whitespace}const {name} = require('{module}');" + result_lines.append(new_line) + logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}") + continue + if namespace_import: + name = namespace_import.group(1) + module = namespace_import.group(2) + new_line = f"{leading_whitespace}const {name} = require('{module}');" + result_lines.append(new_line) + logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}") + continue + + result_lines.append(line) + + return "\n".join(result_lines) + + +def fix_jest_mock_paths(test_code: str, test_file_path: Path, source_file_path: Path, tests_root: Path) -> str: + """Fix relative paths in jest.mock() calls to be correct from the test file's location. + + The AI sometimes generates jest.mock() calls with paths relative to the source file + instead of the test file. For example: + - Source at `src/queue/queue.ts` imports `../environment` (-> src/environment) + - Test at `tests/test.test.ts` generates `jest.mock('../environment')` (-> ./environment, wrong!) + - Should generate `jest.mock('../src/environment')` + + This function detects relative mock paths and adjusts them based on the test file's + location relative to the source file's directory. + + Args: + test_code: The generated test code. + test_file_path: Path to the test file being generated. + source_file_path: Path to the source file being tested. + tests_root: Root directory of the tests. + + Returns: + Fixed test code with corrected mock paths. + + """ + if not test_code or not test_code.strip(): + return test_code + + import os + + # Get the directory containing the source file and the test file + source_dir = source_file_path.resolve().parent + test_dir = test_file_path.resolve().parent + project_root = tests_root.resolve().parent if tests_root.name == "tests" else tests_root.resolve() + + # Pattern to match jest.mock() or jest.doMock() with relative paths + mock_pattern = re.compile(r"(jest\.(?:mock|doMock)\s*\(\s*['\"])(\.\./[^'\"]+|\.\/[^'\"]+)(['\"])") + + def fix_mock_path(match: re.Match[str]) -> str: + original = match.group(0) + prefix = match.group(1) + rel_path = match.group(2) + suffix = match.group(3) + + # Resolve the path as if it were relative to the source file's directory + # (which is how the AI often generates it) + source_relative_resolved = (source_dir / rel_path).resolve() + + # Check if this resolved path exists or if adjusting it would make more sense + # Calculate what the correct relative path from the test file should be + try: + # First, try to find if the path makes sense from the test directory + test_relative_resolved = (test_dir / rel_path).resolve() + + # If the path exists relative to test dir, keep it + if test_relative_resolved.exists() or ( + test_relative_resolved.with_suffix(".ts").exists() + or test_relative_resolved.with_suffix(".js").exists() + or test_relative_resolved.with_suffix(".tsx").exists() + or test_relative_resolved.with_suffix(".jsx").exists() + ): + return original # Keep original, it's valid + + # If path exists relative to source dir, recalculate from test dir + if source_relative_resolved.exists() or ( + source_relative_resolved.with_suffix(".ts").exists() + or source_relative_resolved.with_suffix(".js").exists() + or source_relative_resolved.with_suffix(".tsx").exists() + or source_relative_resolved.with_suffix(".jsx").exists() + ): + # Calculate the correct relative path from test_dir to source_relative_resolved + new_rel_path = os.path.relpath(str(source_relative_resolved), str(test_dir)) + # Ensure it starts with ./ or ../ + if not new_rel_path.startswith("../") and not new_rel_path.startswith("./"): + new_rel_path = f"./{new_rel_path}" + # Use forward slashes + new_rel_path = new_rel_path.replace("\\", "/") + + logger.debug(f"Fixed jest.mock path: {rel_path} -> {new_rel_path}") + return f"{prefix}{new_rel_path}{suffix}" + + except (ValueError, OSError): + pass # Path resolution failed, keep original + + return original # Keep original if we can't fix it + + return mock_pattern.sub(fix_mock_path, test_code) diff --git a/codeflash/languages/javascript/parse.py b/codeflash/languages/javascript/parse.py index 0d62b50b4..c16c551bf 100644 --- a/codeflash/languages/javascript/parse.py +++ b/codeflash/languages/javascript/parse.py @@ -297,7 +297,7 @@ def parse_jest_test_xml( # Infer test type from filename pattern filename = test_file_path.name if "__perf_test_" in filename or "_perf_test_" in filename: - test_type = TestType.GENERATED_PERFORMANCE + test_type = TestType.GENERATED_REGRESSION # Performance tests are still generated regression tests elif "__unit_test_" in filename or "_unit_test_" in filename: test_type = TestType.GENERATED_REGRESSION else: diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 17c3b1021..68a4b9c11 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -104,6 +104,12 @@ def discover_functions( if not criteria.include_async and func.is_async: continue + # Skip non-exported functions (can't be imported in tests) + # Exception: nested functions and methods are allowed if their parent is exported + if not func.is_exported and not func.parent_function: + logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004 + continue + # Build parents list parents: list[FunctionParent] = [] if func.class_name: @@ -326,8 +332,14 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, else: target_code = "" + imports = analyzer.find_imports(source) + + # Find helper functions called by target (needed before class wrapping to find same-class helpers) + helpers = self._find_helper_functions(function, source, analyzer, imports, module_root) + # For class methods, wrap the method in its class definition # This is necessary because method definition syntax is only valid inside a class body + same_class_helper_names: set[str] = set() if function.is_method and function.parents: class_name = None for parent in function.parents: @@ -336,17 +348,26 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, break if class_name: + # Find same-class helper methods that need to be included inside the class wrapper + same_class_helpers = self._find_same_class_helpers( + class_name, function.function_name, helpers, tree_functions, lines + ) + same_class_helper_names = {h[0] for h in same_class_helpers} # method names + # Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields class_info = self._find_class_definition(source, class_name, analyzer, function.function_name) if class_info: class_jsdoc, class_indent, constructor_code, fields_code = class_info - # Build the class body with fields, constructor, and target method + # Build the class body with fields, constructor, target method, and same-class helpers class_body_parts = [] if fields_code: class_body_parts.append(fields_code) if constructor_code: class_body_parts.append(constructor_code) class_body_parts.append(target_code) + # Add same-class helper methods inside the class body + for _helper_name, helper_source in same_class_helpers: + class_body_parts.append(helper_source) class_body = "\n".join(class_body_parts) # Wrap the method in a class definition with context @@ -357,13 +378,16 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, else: target_code = f"{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n" else: - # Fallback: wrap with no indentation - target_code = f"class {class_name} {{\n{target_code}}}\n" - - imports = analyzer.find_imports(source) + # Fallback: wrap with no indentation, including same-class helpers + helper_code = "\n".join(h[1] for h in same_class_helpers) + if helper_code: + target_code = f"class {class_name} {{\n{target_code}\n{helper_code}}}\n" + else: + target_code = f"class {class_name} {{\n{target_code}}}\n" - # Find helper functions called by target - helpers = self._find_helper_functions(function, source, analyzer, imports, module_root) + # Filter out same-class helpers from the helpers list (they're already inside the class wrapper) + if same_class_helper_names: + helpers = [h for h in helpers if h.name not in same_class_helper_names] # Extract import statements as strings import_lines = [] @@ -546,6 +570,49 @@ def _extract_class_context( return (constructor_code, fields_code) + def _find_same_class_helpers( + self, + class_name: str, + target_method_name: str, + helpers: list[HelperFunction], + tree_functions: list, + lines: list[str], + ) -> list[tuple[str, str]]: + """Find helper methods that belong to the same class as the target method. + + These helpers need to be included inside the class wrapper rather than + appended outside, because they may use class-specific syntax like 'private'. + + Args: + class_name: Name of the class containing the target method. + target_method_name: Name of the target method (to exclude). + helpers: List of all helper functions found. + tree_functions: List of FunctionNode from tree-sitter analysis. + lines: Source code split into lines. + + Returns: + List of (method_name, source_code) tuples for same-class helpers. + + """ + same_class_helpers: list[tuple[str, str]] = [] + + # Build a set of helper names for quick lookup + helper_names = {h.name for h in helpers} + + # Names to exclude from same-class helpers (target method and constructor) + exclude_names = {target_method_name, "constructor"} + + # Find methods in tree_functions that belong to the same class and are helpers + for func in tree_functions: + if func.class_name == class_name and func.name in helper_names and func.name not in exclude_names: + # Extract source including JSDoc if present + effective_start = func.doc_start_line or func.start_line + helper_lines = lines[effective_start - 1 : func.end_line] + helper_source = "".join(helper_lines) + same_class_helpers.append((func.name, helper_source)) + + return same_class_helpers + def _find_helper_functions( self, function: FunctionToOptimize, diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index c65adfa7b..ded22a514 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -535,7 +535,6 @@ def run_jest_behavioral_tests( # Get test files to run test_files = [str(file.instrumented_behavior_file_path) for file in test_paths.test_files] - # Use provided project_root, or detect it as fallback if project_root is None and test_files: first_test_file = Path(test_files[0]) @@ -774,14 +773,12 @@ def run_jest_benchmarking_tests( # Get performance test files test_files = [str(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path] - # Use provided project_root, or detect it as fallback if project_root is None and test_files: first_test_file = Path(test_files[0]) project_root = _find_node_project_root(first_test_file) effective_cwd = project_root if project_root else cwd - logger.debug(f"Jest benchmarking working directory: {effective_cwd}") # Ensure the codeflash npm package is installed _ensure_runtime_files(effective_cwd) @@ -792,7 +789,7 @@ def run_jest_benchmarking_tests( "jest", "--reporters=default", "--reporters=jest-junit", - "--runInBand", # Ensure serial execution even though runner enforces it + "--runInBand", # Ensure serial execution "--forceExit", "--runner=codeflash/loop-runner", # Use custom loop runner for in-process looping ] @@ -844,6 +841,12 @@ def run_jest_benchmarking_tests( jest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false" jest_env["CODEFLASH_LOOP_INDEX"] = "1" # Initial value for compatibility + # Enable console output for timing markers + # Some projects mock console.log in test setup (e.g., based on LOG_LEVEL or DEBUG) + # We need console.log to work for capturePerf timing markers + jest_env["LOG_LEVEL"] = "info" # Disable console.log mocking in projects that check LOG_LEVEL + jest_env["DEBUG"] = "1" # Disable console.log mocking in projects that check DEBUG + # Configure ESM support if project uses ES Modules _configure_esm_environment(jest_env, effective_cwd) diff --git a/codeflash/languages/treesitter_utils.py b/codeflash/languages/treesitter_utils.py index f4b7ead43..530a2c47a 100644 --- a/codeflash/languages/treesitter_utils.py +++ b/codeflash/languages/treesitter_utils.py @@ -69,6 +69,7 @@ class FunctionNode: parent_function: str | None source_text: str doc_start_line: int | None = None # Line where JSDoc comment starts (or None if no JSDoc) + is_exported: bool = False # Whether the function is exported @dataclass @@ -292,6 +293,7 @@ def _extract_function_info( is_generator = False is_method = False is_arrow = node.type == "arrow_function" + is_exported = False # Check for async modifier for child in node.children: @@ -303,6 +305,12 @@ def _extract_function_info( if "generator" in node.type: is_generator = True + # Check if function is exported + # For function_declaration: check if parent is export_statement + # For arrow functions: check if parent variable_declarator's grandparent is export_statement + # For CommonJS: check module.exports = { name } or exports.name = ... + is_exported = self._is_node_exported(node, source_bytes) + # Get function name based on node type if node.type in ("function_declaration", "generator_function_declaration"): name_node = node.child_by_field_name("name") @@ -352,8 +360,157 @@ def _extract_function_info( parent_function=current_function, source_text=source_text, doc_start_line=doc_start_line, + is_exported=is_exported, ) + def _is_node_exported(self, node: Node, source_bytes: bytes | None = None) -> bool: + """Check if a function node is exported. + + Handles various export patterns: + - export function foo() {} + - export const foo = () => {} + - export default function foo() {} + - Class methods in exported classes + - module.exports = { foo } (CommonJS) + - exports.foo = ... (CommonJS) + + Args: + node: The function node to check. + source_bytes: Source code bytes (needed for CommonJS export detection). + + Returns: + True if the function is exported, False otherwise. + + """ + # Check direct parent for export_statement + if node.parent and node.parent.type == "export_statement": + return True + + # For arrow functions and function expressions assigned to variables + # e.g., export const foo = () => {} + if node.type in ("arrow_function", "function_expression", "generator_function"): + parent = node.parent + if parent and parent.type == "variable_declarator": + grandparent = parent.parent + if grandparent and grandparent.type in ("lexical_declaration", "variable_declaration"): + great_grandparent = grandparent.parent + if great_grandparent and great_grandparent.type == "export_statement": + return True + + # For methods in exported classes + if node.type == "method_definition": + # Walk up to find class_declaration + current = node.parent + while current: + if current.type in ("class_declaration", "class"): + # Check if this class is exported via ES module export + if current.parent and current.parent.type == "export_statement": + return True + # Check if class is exported via CommonJS + if source_bytes: + class_name_node = current.child_by_field_name("name") + if class_name_node: + class_name = self.get_node_text(class_name_node, source_bytes) + if self._is_name_in_commonjs_exports(node, class_name, source_bytes): + return True + break + current = current.parent + + # Check CommonJS exports: module.exports = { foo } or exports.foo = ... + if source_bytes: + func_name = self._get_function_name_for_export_check(node, source_bytes) + if func_name and self._is_name_in_commonjs_exports(node, func_name, source_bytes): + return True + + return False + + def _get_function_name_for_export_check(self, node: Node, source_bytes: bytes) -> str | None: + """Get the function name for export checking.""" + if node.type in ("function_declaration", "generator_function_declaration"): + name_node = node.child_by_field_name("name") + if name_node: + return self.get_node_text(name_node, source_bytes) + elif node.type in ("arrow_function", "function_expression", "generator_function"): + # Get name from variable assignment + parent = node.parent + if parent and parent.type == "variable_declarator": + name_node = parent.child_by_field_name("name") + if name_node and name_node.type == "identifier": + return self.get_node_text(name_node, source_bytes) + return None + + def _is_name_in_commonjs_exports(self, node: Node, name: str, source_bytes: bytes) -> bool: + """Check if a name is exported via CommonJS module.exports or exports. + + Handles patterns like: + - module.exports = { foo, bar } + - module.exports = { foo: someFunc } + - exports.foo = ... + - module.exports.foo = ... + + Args: + node: Any node in the tree (used to find the program root). + name: The name to check for in exports. + source_bytes: Source code bytes. + + Returns: + True if the name is in CommonJS exports. + + """ + # Walk up to find program root + root = node + while root.parent: + root = root.parent + + # Search for CommonJS export patterns in program children + for child in root.children: + if child.type == "expression_statement": + # Look for assignment expressions + for expr in child.children: + if expr.type == "assignment_expression": + if self._check_commonjs_assignment_exports(expr, name, source_bytes): + return True + + return False + + def _check_commonjs_assignment_exports(self, node: Node, name: str, source_bytes: bytes) -> bool: + """Check if a CommonJS assignment exports the given name.""" + left_node = node.child_by_field_name("left") + right_node = node.child_by_field_name("right") + + if not left_node or not right_node: + return False + + left_text = self.get_node_text(left_node, source_bytes) + + # Check module.exports = { name, ... } or module.exports = { key: name, ... } + if left_text == "module.exports" and right_node.type == "object": + for child in right_node.children: + if child.type == "shorthand_property_identifier": + # { foo } - shorthand export + if self.get_node_text(child, source_bytes) == name: + return True + elif child.type == "pair": + # { key: value } - check both key and value + key_node = child.child_by_field_name("key") + value_node = child.child_by_field_name("value") + if key_node and self.get_node_text(key_node, source_bytes) == name: + return True + if value_node and value_node.type == "identifier": + if self.get_node_text(value_node, source_bytes) == name: + return True + + # Check module.exports = name (single export) + if left_text == "module.exports" and right_node.type == "identifier": + if self.get_node_text(right_node, source_bytes) == name: + return True + + # Check module.exports.name = ... or exports.name = ... + if left_text in {f"module.exports.{name}", f"exports.{name}"}: + return True + + return False + def _find_preceding_jsdoc(self, node: Node, source_bytes: bytes) -> int | None: """Find JSDoc comment immediately preceding a function node. @@ -1580,9 +1737,9 @@ def get_analyzer_for_file(file_path: Path) -> TreeSitterAnalyzer: """ suffix = file_path.suffix.lower() - if suffix in (".ts",): + if suffix == ".ts": return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT) - if suffix in (".tsx",): + if suffix == ".tsx": return TreeSitterAnalyzer(TreeSitterLanguage.TSX) # Default to JavaScript for .js, .jsx, .mjs, .cjs return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT) diff --git a/codeflash/models/test_type.py b/codeflash/models/test_type.py index e3f196756..154e3f7f2 100644 --- a/codeflash/models/test_type.py +++ b/codeflash/models/test_type.py @@ -10,9 +10,7 @@ class TestType(Enum): INIT_STATE_TEST = 6 def to_name(self) -> str: - if self is TestType.INIT_STATE_TEST: - return "" - return _TO_NAME_MAP[self] + return _TO_NAME_MAP.get(self, "") _TO_NAME_MAP: dict[TestType, str] = { diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 08f78ba58..194d1676b 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -604,6 +604,11 @@ def generate_and_instrument_tests( f.write(generated_test.instrumented_behavior_test_source) logger.debug(f"[PIPELINE] Wrote behavioral test to {generated_test.behavior_file_path}") + # Save perf test source for debugging + debug_file_path = get_run_tmp_file(Path("perf_test_debug.test.ts")) + with debug_file_path.open("w", encoding="utf-8") as debug_f: + debug_f.write(generated_test.instrumented_perf_test_source) + with generated_test.perf_file_path.open("w", encoding="utf8") as f: f.write(generated_test.instrumented_perf_test_source) logger.debug(f"[PIPELINE] Wrote perf test to {generated_test.perf_file_path}") @@ -2103,7 +2108,7 @@ def process_review( formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds) generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n" - existing_tests, replay_tests, _ = existing_tests_source_for( + existing_tests, replay_tests, _concolic_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), function_to_all_tests, test_cfg=self.test_cfg, diff --git a/codeflash/verification/coverage_utils.py b/codeflash/verification/coverage_utils.py index 54e8a65ba..cf5339a02 100644 --- a/codeflash/verification/coverage_utils.py +++ b/codeflash/verification/coverage_utils.py @@ -58,7 +58,9 @@ def load_from_jest_json( source_path_str = str(source_code_path.resolve()) for file_path, file_data in coverage_data.items(): - if file_path == source_path_str or file_path.endswith(source_code_path.name): + # Match exact path or path ending with full relative path from src/ + # Avoid matching files with same name in different directories (e.g., db/utils.ts vs utils/utils.ts) + if file_path == source_path_str or file_path.endswith(str(source_code_path)): file_coverage = file_data break diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index f351bd262..78bd2e4ab 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -42,7 +42,20 @@ def generate_tests( source_file = Path(function_to_optimize.file_path) project_module_system = detect_module_system(test_cfg.tests_project_rootdir, source_file) - logger.debug(f"Detected module system: {project_module_system}") + + # For JavaScript, calculate the correct import path from the actual test location + # (test_path) to the source file, not from tests_root + import os + + source_file_abs = source_file.resolve().with_suffix("") + test_dir_abs = test_path.resolve().parent + # Compute relative path from test directory to source file + rel_import_path = os.path.relpath(str(source_file_abs), str(test_dir_abs)) + # Ensure path starts with ./ or ../ for JavaScript/TypeScript imports + if not rel_import_path.startswith("../"): + rel_import_path = f"./{rel_import_path}" + # Keep as string since Path() normalizes away the ./ prefix + module_path = rel_import_path response = aiservice_client.generate_regression_tests( source_code_being_tested=source_code_being_tested, @@ -66,6 +79,8 @@ def generate_tests( if is_javascript(): from codeflash.languages.javascript.instrument import ( TestingMode, + fix_imports_inside_test_blocks, + fix_jest_mock_paths, instrument_generated_js_test, validate_and_fix_import_style, ) @@ -76,6 +91,14 @@ def generate_tests( source_file = Path(function_to_optimize.file_path) + # Fix import statements that appear inside test blocks (invalid JS syntax) + generated_test_source = fix_imports_inside_test_blocks(generated_test_source) + + # Fix relative paths in jest.mock() calls + generated_test_source = fix_jest_mock_paths( + generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir + ) + # Validate and fix import styles (default vs named exports) generated_test_source = validate_and_fix_import_style( generated_test_source, source_file, function_to_optimize.function_name diff --git a/packages/codeflash/runtime/capture.js b/packages/codeflash/runtime/capture.js index eabcee539..8b8d91c33 100644 --- a/packages/codeflash/runtime/capture.js +++ b/packages/codeflash/runtime/capture.js @@ -87,6 +87,8 @@ if (!process[PERF_STATE_KEY]) { shouldStop: false, // Flag to stop all further looping currentBatch: 0, // Current batch number (incremented by runner) invocationLoopCounts: {}, // Track loops per invocation: {invocationKey: loopCount} + invocationRuntimes: {}, // Track runtimes per invocation for stability: {invocationKey: [runtimes]} + stableInvocations: {}, // Invocations that have reached stability: {invocationKey: true} }; } const sharedPerfState = process[PERF_STATE_KEY]; @@ -265,26 +267,40 @@ const results = []; let db = null; /** - * Check if performance has stabilized (for internal looping). - * Matches Python's pytest_plugin.should_stop() logic. + * Check if performance has stabilized, allowing early stopping of benchmarks. + * Matches Python's pytest_plugin.should_stop() logic for consistency. + * + * Performance is considered stable when BOTH conditions are met: + * 1. CENTER: All recent measurements are within ±10% of the median + * 2. SPREAD: The range (max-min) is within 10% of the minimum + * + * @param {Array} runtimes - Array of runtime measurements in microseconds + * @param {number} window - Number of recent measurements to check + * @param {number} minWindowSize - Minimum samples required before checking + * @returns {boolean} True if performance has stabilized */ function shouldStopStability(runtimes, window, minWindowSize) { if (runtimes.length < window || runtimes.length < minWindowSize) { return false; } + const recent = runtimes.slice(-window); const recentSorted = [...recent].sort((a, b) => a - b); const mid = Math.floor(window / 2); const median = window % 2 ? recentSorted[mid] : (recentSorted[mid - 1] + recentSorted[mid]) / 2; + // Check CENTER: all recent points must be close to median for (const r of recent) { if (Math.abs(r - median) / median > STABILITY_CENTER_TOLERANCE) { return false; } } + + // Check SPREAD: range must be small relative to minimum const rMin = recentSorted[0]; const rMax = recentSorted[recentSorted.length - 1]; if (rMin === 0) return false; + return (rMax - rMin) / rMin <= STABILITY_SPREAD_TOLERANCE; } @@ -673,12 +689,26 @@ function capturePerf(funcName, lineId, fn, ...args) { ? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount()) : 1; + // Initialize runtime tracking for this invocation if needed + if (!sharedPerfState.invocationRuntimes[invocationKey]) { + sharedPerfState.invocationRuntimes[invocationKey] = []; + } + const runtimes = sharedPerfState.invocationRuntimes[invocationKey]; + + // Calculate stability window size based on collected runtimes + const getStabilityWindow = () => Math.max(getPerfMinLoops(), Math.ceil(runtimes.length * STABILITY_WINDOW_SIZE)); + for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) { // Check shared time limit BEFORE each iteration if (shouldLoop && checkSharedTimeLimit()) { break; } + // Check if this invocation has already reached stability + if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { + break; + } + // Get the global loop index for this invocation (increments across batches) const loopIndex = getInvocationLoopIndex(invocationKey); @@ -703,23 +733,17 @@ function capturePerf(funcName, lineId, fn, ...args) { const endTime = getTimeNs(); durationNs = getDurationNs(startTime, endTime); - // Handle promises - for async functions, run once and return + // Handle promises - for async functions, we need to handle looping differently + // Since we can't use await in the sync loop, delegate to async helper if (lastReturnValue instanceof Promise) { - return lastReturnValue.then( - (resolved) => { - const asyncEndTime = getTimeNs(); - const asyncDurationNs = getDurationNs(startTime, asyncEndTime); - console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`); - sharedPerfState.totalLoopsCompleted++; - return resolved; - }, - (err) => { - const asyncEndTime = getTimeNs(); - const asyncDurationNs = getDurationNs(startTime, asyncEndTime); - console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`); - sharedPerfState.totalLoopsCompleted++; - throw err; - } + // For async functions, delegate to the async looping helper + // Pass along all the context needed for continued looping + return _capturePerfAsync( + funcName, lineId, fn, args, + lastReturnValue, startTime, testStdoutTag, + safeModulePath, testClassName, safeTestFunctionName, + invocationKey, runtimes, batchSize, batchIndex, + shouldLoop, getStabilityWindow ); } @@ -735,6 +759,20 @@ function capturePerf(funcName, lineId, fn, ...args) { // Update shared loop counter sharedPerfState.totalLoopsCompleted++; + // Track runtime for stability check (convert to microseconds) + if (durationNs > 0) { + runtimes.push(durationNs / 1000); + } + + // Check stability after accumulating enough samples + if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { + const window = getStabilityWindow(); + if (shouldStopStability(runtimes, window, getPerfMinLoops())) { + sharedPerfState.stableInvocations[invocationKey] = true; + break; + } + } + // If we had an error, stop looping if (lastError) { break; @@ -751,6 +789,119 @@ function capturePerf(funcName, lineId, fn, ...args) { return lastReturnValue; } +/** + * Helper to record async timing and update state. + * @private + */ +function _recordAsyncTiming(startTime, testStdoutTag, durationNs, runtimes) { + console.log(`!######${testStdoutTag}:${durationNs}######!`); + sharedPerfState.totalLoopsCompleted++; + if (durationNs > 0) { + runtimes.push(durationNs / 1000); + } +} + +/** + * Async helper for capturePerf to handle async function looping. + * This function awaits promises and continues the benchmark loop properly. + * + * @private + * @param {string} funcName - Name of the function being benchmarked + * @param {string} lineId - Line identifier for this capture point + * @param {Function} fn - The async function to benchmark + * @param {Array} args - Arguments to pass to fn + * @param {Promise} firstPromise - The first promise that was already started + * @param {number} firstStartTime - Start time of the first execution + * @param {string} firstTestStdoutTag - Timing marker tag for the first execution + * @param {string} safeModulePath - Sanitized module path + * @param {string|null} testClassName - Test class name (if any) + * @param {string} safeTestFunctionName - Sanitized test function name + * @param {string} invocationKey - Unique key for this invocation + * @param {Array} runtimes - Array to collect runtimes for stability checking + * @param {number} batchSize - Number of iterations per batch + * @param {number} startBatchIndex - Index where async looping started + * @param {boolean} shouldLoop - Whether to continue looping + * @param {Function} getStabilityWindow - Function to get stability window size + * @returns {Promise} The last return value from fn + */ +async function _capturePerfAsync( + funcName, lineId, fn, args, + firstPromise, firstStartTime, firstTestStdoutTag, + safeModulePath, testClassName, safeTestFunctionName, + invocationKey, runtimes, batchSize, startBatchIndex, + shouldLoop, getStabilityWindow +) { + let lastReturnValue; + let lastError = null; + + // Handle the first promise that was already started + try { + lastReturnValue = await firstPromise; + const asyncEndTime = getTimeNs(); + const asyncDurationNs = getDurationNs(firstStartTime, asyncEndTime); + _recordAsyncTiming(firstStartTime, firstTestStdoutTag, asyncDurationNs, runtimes); + } catch (err) { + const asyncEndTime = getTimeNs(); + const asyncDurationNs = getDurationNs(firstStartTime, asyncEndTime); + _recordAsyncTiming(firstStartTime, firstTestStdoutTag, asyncDurationNs, runtimes); + lastError = err; + // Don't throw yet - we want to record the timing first + } + + // If first iteration failed, stop and throw + if (lastError) { + throw lastError; + } + + // Continue looping for remaining iterations + for (let batchIndex = startBatchIndex + 1; batchIndex < batchSize; batchIndex++) { + // Check exit conditions before starting next iteration + if (shouldLoop && checkSharedTimeLimit()) { + break; + } + + if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { + break; + } + + const loopIndex = getInvocationLoopIndex(invocationKey); + if (loopIndex > getPerfLoopCount()) { + break; + } + + // Generate timing marker identifiers + const testId = `${safeModulePath}:${testClassName}:${safeTestFunctionName}:${lineId}:${loopIndex}`; + const invocationIndex = getInvocationIndex(testId); + const invocationId = `${lineId}_${invocationIndex}`; + const testStdoutTag = `${safeModulePath}:${testClassName ? testClassName + '.' : ''}${safeTestFunctionName}:${funcName}:${loopIndex}:${invocationId}`; + + // Execute and time the function + try { + const startTime = getTimeNs(); + lastReturnValue = await fn(...args); + const endTime = getTimeNs(); + const durationNs = getDurationNs(startTime, endTime); + + _recordAsyncTiming(startTime, testStdoutTag, durationNs, runtimes); + + // Check if we've reached performance stability + if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { + const window = getStabilityWindow(); + if (shouldStopStability(runtimes, window, getPerfMinLoops())) { + sharedPerfState.stableInvocations[invocationKey] = true; + break; + } + } + } catch (e) { + lastError = e; + break; + } + } + + if (lastError) throw lastError; + return lastReturnValue; +} + /** * Capture multiple invocations for benchmarking. * @@ -806,6 +957,8 @@ function resetPerfState() { sharedPerfState.startTime = null; sharedPerfState.totalLoopsCompleted = 0; sharedPerfState.shouldStop = false; + sharedPerfState.invocationRuntimes = {}; + sharedPerfState.stableInvocations = {}; } /** diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index 6bfde0c4c..33f9f7274 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -24,6 +24,8 @@ * NOTE: This runner requires jest-runner to be installed in your project. * It is a Jest-specific feature and does not work with Vitest. * For Vitest projects, capturePerf() does all loops internally in a single call. + * + * Compatibility: Works with Jest 29.x and Jest 30.x */ 'use strict'; @@ -32,10 +34,26 @@ const { createRequire } = require('module'); const path = require('path'); const fs = require('fs'); +/** + * Validates that a jest-runner path is valid by checking for package.json. + * @param {string} jestRunnerPath - Path to check + * @returns {boolean} True if valid jest-runner package + */ +function isValidJestRunnerPath(jestRunnerPath) { + if (!fs.existsSync(jestRunnerPath)) { + return false; + } + const packageJsonPath = path.join(jestRunnerPath, 'package.json'); + return fs.existsSync(packageJsonPath); +} + /** * Resolve jest-runner with monorepo support. * Uses CODEFLASH_MONOREPO_ROOT environment variable if available, * otherwise walks up the directory tree looking for node_modules/jest-runner. + * + * @returns {string} Path to jest-runner package + * @throws {Error} If jest-runner cannot be found */ function resolveJestRunner() { // Try standard resolution first (works in simple projects) @@ -49,11 +67,8 @@ function resolveJestRunner() { const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT; if (monorepoRoot) { const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner'); - if (fs.existsSync(jestRunnerPath)) { - const packageJsonPath = path.join(jestRunnerPath, 'package.json'); - if (fs.existsSync(packageJsonPath)) { - return jestRunnerPath; - } + if (isValidJestRunnerPath(jestRunnerPath)) { + return jestRunnerPath; } } @@ -69,11 +84,8 @@ function resolveJestRunner() { // Try node_modules/jest-runner at this level const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner'); - if (fs.existsSync(jestRunnerPath)) { - const packageJsonPath = path.join(jestRunnerPath, 'package.json'); - if (fs.existsSync(packageJsonPath)) { - return jestRunnerPath; - } + if (isValidJestRunnerPath(jestRunnerPath)) { + return jestRunnerPath; } // Check if this is a workspace root (has monorepo markers) @@ -89,18 +101,53 @@ function resolveJestRunner() { currentDir = path.dirname(currentDir); } - throw new Error('jest-runner not found'); + throw new Error( + 'jest-runner not found. Please install jest-runner in your project: npm install --save-dev jest-runner' + ); } -// Try to load jest-runner - it's a peer dependency that must be installed by the user +/** + * Jest runner components - loaded dynamically from project's node_modules. + * This ensures we use the same version that the project uses. + * + * Jest 30+ uses TestRunner class with event-based architecture. + * Jest 29 uses runTest function for direct test execution. + */ +let TestRunner; let runTest; let jestRunnerAvailable = false; +let jestVersion = 0; try { const jestRunnerPath = resolveJestRunner(); const internalRequire = createRequire(jestRunnerPath); - runTest = internalRequire('./runTest').default; - jestRunnerAvailable = true; + + // Try to get the TestRunner class (Jest 30+) + const jestRunner = internalRequire(jestRunnerPath); + TestRunner = jestRunner.default || jestRunner.TestRunner; + + if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') { + // Jest 30+ - use TestRunner class with event emitter pattern + jestVersion = 30; + jestRunnerAvailable = true; + } else { + // Try Jest 29 style import + try { + runTest = internalRequire('./runTest').default; + if (typeof runTest === 'function') { + // Jest 29 - use direct runTest function + jestVersion = 29; + jestRunnerAvailable = true; + } + } catch (e29) { + // Neither Jest 29 nor 30 style import worked + const errorMsg = `Found jest-runner at ${jestRunnerPath} but could not load it. ` + + `This may indicate an unsupported Jest version. ` + + `Supported versions: Jest 29.x and Jest 30.x`; + console.error(errorMsg); + jestRunnerAvailable = false; + } + } } catch (e) { // jest-runner not installed - this is expected for Vitest projects // The runner will throw a helpful error if someone tries to use it without jest-runner @@ -167,6 +214,9 @@ function deepCopy(obj, seen = new WeakMap()) { /** * Codeflash Loop Runner with Batched Looping + * + * For Jest 30+, extends the TestRunner class directly. + * For Jest 29, uses the runTest function import. */ class CodeflashLoopRunner { constructor(globalConfig, context) { @@ -175,12 +225,24 @@ class CodeflashLoopRunner { 'codeflash/loop-runner requires jest-runner to be installed.\n' + 'Please install it: npm install --save-dev jest-runner\n\n' + 'If you are using Vitest, the loop-runner is not needed - ' + - 'Vitest projects use external looping handled by the Python runner.' + 'Vitest projects use internal looping handled by capturePerf().' ); } + this._globalConfig = globalConfig; this._context = context || {}; this._eventEmitter = new SimpleEventEmitter(); + + // For Jest 30+, create an instance of the base TestRunner for delegation + if (jestVersion >= 30) { + if (!TestRunner) { + throw new Error( + `Jest ${jestVersion} detected but TestRunner class not available. ` + + `This indicates an internal error in loop-runner initialization.` + ); + } + this._baseRunner = new TestRunner(globalConfig, context); + } } get supportsEventEmitters() { @@ -196,7 +258,17 @@ class CodeflashLoopRunner { } /** - * Run tests with batched looping for fair distribution. + * Run tests with batched looping for fair distribution across all test invocations. + * + * This implements the batched looping strategy: + * Batch 1: Test1(N loops) → Test2(N loops) → Test3(N loops) + * Batch 2: Test1(N loops) → Test2(N loops) → Test3(N loops) + * ...until time budget exhausted or max batches reached + * + * @param {Array} tests - Jest test objects to run + * @param {Object} watcher - Jest watcher for interrupt handling + * @param {Object} options - Jest runner options + * @returns {Promise} */ async runTests(tests, watcher, options) { const startTime = Date.now(); @@ -204,29 +276,20 @@ class CodeflashLoopRunner { let hasFailure = false; let allConsoleOutput = ''; - // Import shared state functions from capture module - // We need to do this dynamically since the module may be reloaded - let checkSharedTimeLimit; - let incrementBatch; - try { - const capture = require('codeflash'); - checkSharedTimeLimit = capture.checkSharedTimeLimit; - incrementBatch = capture.incrementBatch; - } catch (e) { - // Fallback if codeflash module not available - checkSharedTimeLimit = () => { - const elapsed = Date.now() - startTime; - return elapsed >= TARGET_DURATION_MS && batchCount >= MIN_BATCHES; - }; - incrementBatch = () => {}; - } + // Time limit check - must use local time tracking because Jest runs tests + // in isolated worker processes where shared state from capture.js isn't accessible + const checkTimeLimit = () => { + const elapsed = Date.now() - startTime; + return elapsed >= TARGET_DURATION_MS && batchCount >= MIN_BATCHES; + }; // Batched looping: run all test files multiple times while (batchCount < MAX_BATCHES) { batchCount++; // Check time limit BEFORE each batch - if (batchCount > MIN_BATCHES && checkSharedTimeLimit()) { + if (batchCount > MIN_BATCHES && checkTimeLimit()) { + console.log(`[codeflash] Time limit reached after ${batchCount - 1} batches (${Date.now() - startTime}ms elapsed)`); break; } @@ -235,13 +298,11 @@ class CodeflashLoopRunner { break; } - // Increment batch counter in shared state and set env var - // The env var persists across Jest module resets, ensuring continuous loop indices - incrementBatch(); + // Set env var for batch number - persists across Jest module resets process.env.CODEFLASH_PERF_CURRENT_BATCH = String(batchCount); // Run all test files in this batch - const batchResult = await this._runAllTestsOnce(tests, watcher); + const batchResult = await this._runAllTestsOnce(tests, watcher, options); allConsoleOutput += batchResult.consoleOutput; if (batchResult.hasFailure) { @@ -250,7 +311,8 @@ class CodeflashLoopRunner { } // Check time limit AFTER each batch - if (checkSharedTimeLimit()) { + if (checkTimeLimit()) { + console.log(`[codeflash] Time limit reached after ${batchCount} batches (${Date.now() - startTime}ms elapsed)`); break; } } @@ -268,8 +330,74 @@ class CodeflashLoopRunner { /** * Run all test files once (one batch). + * Uses different approaches for Jest 29 vs Jest 30. + */ + async _runAllTestsOnce(tests, watcher, options) { + if (jestVersion >= 30) { + return this._runAllTestsOnceJest30(tests, watcher, options); + } else { + return this._runAllTestsOnceJest29(tests, watcher); + } + } + + /** + * Jest 30+ implementation - delegates to base TestRunner and collects results. + */ + async _runAllTestsOnceJest30(tests, watcher, options) { + let hasFailure = false; + let allConsoleOutput = ''; + + // For Jest 30, we need to collect results through event listeners + const resultsCollector = []; + + // Subscribe to events from the base runner + const unsubscribeSuccess = this._baseRunner.on('test-file-success', (testData) => { + const [test, result] = testData; + resultsCollector.push({ test, result, success: true }); + + if (result && result.console && Array.isArray(result.console)) { + allConsoleOutput += result.console.map(e => e.message || '').join('\n') + '\n'; + } + + if (result && result.numFailingTests > 0) { + hasFailure = true; + } + + // Forward to our event emitter + this._eventEmitter.emit('test-file-success', testData); + }); + + const unsubscribeFailure = this._baseRunner.on('test-file-failure', (testData) => { + const [test, error] = testData; + resultsCollector.push({ test, error, success: false }); + hasFailure = true; + + // Forward to our event emitter + this._eventEmitter.emit('test-file-failure', testData); + }); + + const unsubscribeStart = this._baseRunner.on('test-file-start', (testData) => { + // Forward to our event emitter + this._eventEmitter.emit('test-file-start', testData); + }); + + try { + // Run tests using the base runner (always serial for benchmarking) + await this._baseRunner.runTests(tests, watcher, { ...options, serial: true }); + } finally { + // Cleanup subscriptions + if (typeof unsubscribeSuccess === 'function') unsubscribeSuccess(); + if (typeof unsubscribeFailure === 'function') unsubscribeFailure(); + if (typeof unsubscribeStart === 'function') unsubscribeStart(); + } + + return { consoleOutput: allConsoleOutput, hasFailure }; + } + + /** + * Jest 29 implementation - uses direct runTest import. */ - async _runAllTestsOnce(tests, watcher) { + async _runAllTestsOnceJest29(tests, watcher) { let hasFailure = false; let allConsoleOutput = ''; diff --git a/tests/test_javascript_function_discovery.py b/tests/test_javascript_function_discovery.py index 9a39086a8..cf76bee2d 100644 --- a/tests/test_javascript_function_discovery.py +++ b/tests/test_javascript_function_discovery.py @@ -23,7 +23,7 @@ def test_simple_function_discovery(self, tmp_path): """Test discovering a simple JavaScript function with return statement.""" js_file = tmp_path / "simple.js" js_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -39,15 +39,15 @@ def test_multiple_functions_discovery(self, tmp_path): """Test discovering multiple JavaScript functions.""" js_file = tmp_path / "multiple.js" js_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } -function divide(a, b) { +export function divide(a, b) { return a / b; } """) @@ -61,11 +61,11 @@ def test_function_without_return_excluded(self, tmp_path): """Test that functions without return statements are excluded.""" js_file = tmp_path / "no_return.js" js_file.write_text(""" -function withReturn() { +export function withReturn() { return 42; } -function withoutReturn() { +export function withoutReturn() { console.log("hello"); } """) @@ -78,11 +78,11 @@ def test_arrow_function_discovery(self, tmp_path): """Test discovering arrow functions with explicit return.""" js_file = tmp_path / "arrow.js" js_file.write_text(""" -const add = (a, b) => { +export const add = (a, b) => { return a + b; }; -const multiply = (a, b) => a * b; +export const multiply = (a, b) => a * b; """) functions = find_all_functions_in_file(js_file) @@ -95,7 +95,7 @@ def test_class_method_discovery(self, tmp_path): """Test discovering methods inside a JavaScript class.""" js_file = tmp_path / "class.js" js_file.write_text(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -120,11 +120,11 @@ def test_async_function_discovery(self, tmp_path): """Test discovering async JavaScript functions.""" js_file = tmp_path / "async.js" js_file.write_text(""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url); } -function syncFunc() { +export function syncFunc() { return 42; } """) @@ -141,7 +141,7 @@ def test_nested_function_excluded(self, tmp_path): """Test that nested functions are handled correctly.""" js_file = tmp_path / "nested.js" js_file.write_text(""" -function outer() { +export function outer() { function inner() { return 1; } @@ -158,11 +158,11 @@ def test_jsx_file_discovery(self, tmp_path): """Test discovering functions in JSX files.""" jsx_file = tmp_path / "component.jsx" jsx_file.write_text(""" -function Button({ onClick }) { +export function Button({ onClick }) { return ; } -function formatText(text) { +export function formatText(text) { return text.toUpperCase(); } """) @@ -176,7 +176,7 @@ def test_invalid_javascript_returns_empty(self, tmp_path): """Test that invalid JavaScript code returns empty results.""" js_file = tmp_path / "invalid.js" js_file.write_text(""" -function broken( { +export function broken( { return 42; } """) @@ -189,11 +189,11 @@ def test_function_line_numbers(self, tmp_path): """Test that function line numbers are correctly detected.""" js_file = tmp_path / "lines.js" js_file.write_text(""" -function firstFunc() { +export function firstFunc() { return 1; } -function secondFunc() { +export function secondFunc() { return 2; } """) @@ -217,7 +217,7 @@ def test_filter_functions_includes_javascript(self, tmp_path): """Test that filter_functions correctly includes JavaScript files.""" js_file = tmp_path / "module.js" js_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -240,7 +240,7 @@ def test_filter_excludes_test_directory(self, tmp_path): tests_dir.mkdir() test_file = tests_dir / "test_module.test.js" test_file.write_text(""" -function testHelper() { +export function testHelper() { return 42; } """) @@ -260,7 +260,7 @@ def test_filter_excludes_ignored_paths(self, tmp_path): ignored_dir.mkdir() js_file = ignored_dir / "ignored_module.js" js_file.write_text(""" -function ignoredFunc() { +export function ignoredFunc() { return 42; } """) @@ -282,7 +282,7 @@ def test_filter_includes_files_with_dashes(self, tmp_path): """Test that JavaScript files with dashes in name are included (unlike Python).""" js_file = tmp_path / "my-module.js" js_file.write_text(""" -function myFunc() { +export function myFunc() { return 42; } """) @@ -312,11 +312,11 @@ def test_get_functions_from_file(self, tmp_path): """Test getting functions to optimize from a JavaScript file.""" js_file = tmp_path / "string_utils.js" js_file.write_text(""" -function reverseString(str) { +export function reverseString(str) { return str.split('').reverse().join(''); } -function capitalize(str) { +export function capitalize(str) { return str.charAt(0).toUpperCase() + str.slice(1); } """) @@ -422,12 +422,12 @@ def test_discover_all_js_functions(self, tmp_path): """Test discovering all JavaScript functions in a directory.""" # Create multiple JS files (tmp_path / "math.js").write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) (tmp_path / "string.js").write_text(""" -function reverse(str) { +export function reverse(str) { return str.split('').reverse().join(''); } """) @@ -451,7 +451,7 @@ def py_func(): return 1 """) (tmp_path / "js_module.js").write_text(""" -function jsFunc() { +export function jsFunc() { return 1; } """) @@ -476,7 +476,7 @@ def test_qualified_name_no_parents(self, tmp_path): """Test qualified name for top-level function.""" js_file = tmp_path / "module.js" js_file.write_text(""" -function topLevel() { +export function topLevel() { return 42; } """) @@ -490,7 +490,7 @@ def test_qualified_name_with_class_parent(self, tmp_path): """Test qualified name for class method.""" js_file = tmp_path / "module.js" js_file.write_text(""" -class MyClass { +export class MyClass { myMethod() { return 42; } @@ -506,7 +506,7 @@ def test_language_attribute(self, tmp_path): """Test that JavaScript functions have correct language attribute.""" js_file = tmp_path / "module.js" js_file.write_text(""" -function myFunc() { +export function myFunc() { return 42; } """) diff --git a/tests/test_languages/fixtures/js_cjs/calculator.js b/tests/test_languages/fixtures/js_cjs/calculator.js index 6a75d8476..8176c0007 100644 --- a/tests/test_languages/fixtures/js_cjs/calculator.js +++ b/tests/test_languages/fixtures/js_cjs/calculator.js @@ -6,7 +6,7 @@ const { add, multiply, factorial } = require('./math_utils'); const { formatNumber, validateInput } = require('./helpers/format'); -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; this.history = []; diff --git a/tests/test_languages/fixtures/js_cjs/helpers/format.js b/tests/test_languages/fixtures/js_cjs/helpers/format.js index d2d50e4df..15dae5e1c 100644 --- a/tests/test_languages/fixtures/js_cjs/helpers/format.js +++ b/tests/test_languages/fixtures/js_cjs/helpers/format.js @@ -8,7 +8,7 @@ * @param decimals - Number of decimal places * @returns Formatted number */ -function formatNumber(num, decimals) { +export function formatNumber(num, decimals) { return Number(num.toFixed(decimals)); } @@ -18,7 +18,7 @@ function formatNumber(num, decimals) { * @param name - Parameter name for error message * @throws Error if value is not a valid number */ -function validateInput(value, name) { +export function validateInput(value, name) { if (typeof value !== 'number' || isNaN(value)) { throw new Error(`Invalid ${name}: must be a number`); } @@ -30,7 +30,7 @@ function validateInput(value, name) { * @param symbol - Currency symbol * @returns Formatted currency string */ -function formatCurrency(amount, symbol = '$') { +export function formatCurrency(amount, symbol = '$') { return `${symbol}${formatNumber(amount, 2)}`; } diff --git a/tests/test_languages/fixtures/js_cjs/math_utils.js b/tests/test_languages/fixtures/js_cjs/math_utils.js index 0b650ed0e..a09a4e880 100644 --- a/tests/test_languages/fixtures/js_cjs/math_utils.js +++ b/tests/test_languages/fixtures/js_cjs/math_utils.js @@ -8,7 +8,7 @@ * @param b - Second number * @returns Sum of a and b */ -function add(a, b) { +export function add(a, b) { return a + b; } @@ -18,7 +18,7 @@ function add(a, b) { * @param b - Second number * @returns Product of a and b */ -function multiply(a, b) { +export function multiply(a, b) { return a * b; } @@ -27,7 +27,7 @@ function multiply(a, b) { * @param n - Non-negative integer * @returns Factorial of n */ -function factorial(n) { +export function factorial(n) { // Intentionally inefficient recursive implementation if (n <= 1) return 1; return n * factorial(n - 1); @@ -39,7 +39,7 @@ function factorial(n) { * @param exp - Exponent * @returns base raised to exp */ -function power(base, exp) { +export function power(base, exp) { // Inefficient: linear time instead of log time let result = 1; for (let i = 0; i < exp; i++) { diff --git a/tests/test_languages/test_code_context_extraction.py b/tests/test_languages/test_code_context_extraction.py index 87c728b34..07946ddd3 100644 --- a/tests/test_languages/test_code_context_extraction.py +++ b/tests/test_languages/test_code_context_extraction.py @@ -56,7 +56,7 @@ class TestSimpleFunctionContext: def test_simple_function_no_dependencies(self, js_support, temp_project): """Test extracting context for a simple standalone function without any dependencies.""" code = """\ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -70,7 +70,7 @@ def test_simple_function_no_dependencies(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -84,7 +84,7 @@ def test_simple_function_no_dependencies(self, js_support, temp_project): def test_arrow_function_with_implicit_return(self, js_support, temp_project): """Test extracting an arrow function with implicit return.""" code = """\ -const multiply = (a, b) => a * b; +export const multiply = (a, b) => a * b; """ file_path = temp_project / "math.js" file_path.write_text(code, encoding="utf-8") @@ -97,7 +97,7 @@ def test_arrow_function_with_implicit_return(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -const multiply = (a, b) => a * b; +export const multiply = (a, b) => a * b; """ assert context.target_code == expected_target_code assert context.helper_functions == [] @@ -116,7 +116,7 @@ def test_function_with_simple_jsdoc(self, js_support, temp_project): * @param {number} b - Second number * @returns {number} The sum */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -129,13 +129,7 @@ def test_function_with_simple_jsdoc(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -/** - * Adds two numbers together. - * @param {number} a - First number - * @param {number} b - Second number - * @returns {number} The sum - */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -163,7 +157,7 @@ def test_function_with_complex_jsdoc_types(self, js_support, temp_project): * const doubled = await processItems([1, 2, 3], x => x * 2); * // returns [2, 4, 6] */ -async function processItems(items, callback, options = {}) { +export async function processItems(items, callback, options = {}) { const { parallel = false, chunkSize = 100 } = options; if (!Array.isArray(items)) { @@ -187,25 +181,7 @@ def test_function_with_complex_jsdoc_types(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -/** - * Processes an array of items with a callback function. - * - * This function iterates over each item and applies the transformation. - * - * @template T - The type of items in the input array - * @template U - The type of items in the output array - * @param {Array} items - The input array to process - * @param {function(T, number): U} callback - Transformation function - * @param {Object} [options] - Optional configuration - * @param {boolean} [options.parallel=false] - Whether to process in parallel - * @param {number} [options.chunkSize=100] - Size of processing chunks - * @returns {Promise>} The transformed array - * @throws {TypeError} If items is not an array - * @example - * const doubled = await processItems([1, 2, 3], x => x * 2); - * // returns [2, 4, 6] - */ -async function processItems(items, callback, options = {}) { +export async function processItems(items, callback, options = {}) { const { parallel = false, chunkSize = 100 } = options; if (!Array.isArray(items)) { @@ -231,7 +207,7 @@ def test_class_with_jsdoc_on_class_and_methods(self, js_support, temp_project): * @class CacheManager * @description Provides in-memory caching with automatic expiration. */ -class CacheManager { +export class CacheManager { /** * Creates a new cache manager. * @param {number} defaultTTL - Default time-to-live in milliseconds @@ -275,12 +251,6 @@ class CacheManager { context = js_support.extract_code_context(get_or_compute, temp_project, temp_project) expected_target_code = """\ -/** - * A cache implementation with TTL support. - * - * @class CacheManager - * @description Provides in-memory caching with automatic expiration. - */ class CacheManager { /** * Creates a new cache manager. @@ -344,7 +314,7 @@ def test_jsdoc_with_typedef_and_callback(self, js_support, temp_project): * @param {ValidatorFunction[]} validators - Array of validator functions * @returns {ValidationResult} Combined validation result */ -function validateUserData(data, validators) { +export function validateUserData(data, validators) { const errors = []; const fieldErrors = {}; @@ -377,13 +347,7 @@ def test_jsdoc_with_typedef_and_callback(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -/** - * Validates user input data. - * @param {Object} data - The data to validate - * @param {ValidatorFunction[]} validators - Array of validator functions - * @returns {ValidationResult} Combined validation result - */ -function validateUserData(data, validators) { +export function validateUserData(data, validators) { const errors = []; const fieldErrors = {}; @@ -433,7 +397,7 @@ def test_function_with_multiple_complex_constants(self, js_support, temp_project }; const UNUSED_CONFIG = { debug: false }; -async function fetchWithRetry(endpoint, options = {}) { +export async function fetchWithRetry(endpoint, options = {}) { const url = API_BASE_URL + endpoint; let lastError; @@ -473,7 +437,7 @@ def test_function_with_multiple_complex_constants(self, js_support, temp_project context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -async function fetchWithRetry(endpoint, options = {}) { +export async function fetchWithRetry(endpoint, options = {}) { const url = API_BASE_URL + endpoint; let lastError; @@ -537,7 +501,7 @@ def test_function_with_regex_and_template_constants(self, js_support, temp_proje url: 'Please enter a valid URL' }; -function validateField(value, fieldType) { +export function validateField(value, fieldType) { const pattern = PATTERNS[fieldType]; if (!pattern) { return { valid: true, error: null }; @@ -559,7 +523,7 @@ def test_function_with_regex_and_template_constants(self, js_support, temp_proje context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function validateField(value, fieldType) { +export function validateField(value, fieldType) { const pattern = PATTERNS[fieldType]; if (!pattern) { return { valid: true, error: null }; @@ -595,16 +559,16 @@ class TestSameFileHelperFunctions: def test_function_with_chain_of_helpers(self, js_support, temp_project): """Test function calling helper that calls another helper (transitive dependencies).""" code = """\ -function sanitizeString(str) { +export function sanitizeString(str) { return str.trim().toLowerCase(); } -function normalizeInput(input) { +export function normalizeInput(input) { const sanitized = sanitizeString(input); return sanitized.replace(/\\s+/g, '-'); } -function processUserInput(rawInput) { +export function processUserInput(rawInput) { const normalized = normalizeInput(rawInput); return { original: rawInput, @@ -622,7 +586,7 @@ def test_function_with_chain_of_helpers(self, js_support, temp_project): context = js_support.extract_code_context(process_func, temp_project, temp_project) expected_target_code = """\ -function processUserInput(rawInput) { +export function processUserInput(rawInput) { const normalized = normalizeInput(rawInput); return { original: rawInput, @@ -640,23 +604,23 @@ def test_function_with_chain_of_helpers(self, js_support, temp_project): def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project): """Test function calling multiple independent helper functions.""" code = """\ -function formatDate(date) { +export function formatDate(date) { return date.toISOString().split('T')[0]; } -function formatCurrency(amount) { +export function formatCurrency(amount) { return '$' + amount.toFixed(2); } -function formatPercentage(value) { +export function formatPercentage(value) { return (value * 100).toFixed(1) + '%'; } -function unusedFormatter() { +export function unusedFormatter() { return 'not used'; } -function generateReport(data) { +export function generateReport(data) { const date = formatDate(new Date(data.timestamp)); const revenue = formatCurrency(data.revenue); const growth = formatPercentage(data.growth); @@ -677,7 +641,7 @@ def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project context = js_support.extract_code_context(report_func, temp_project, temp_project) expected_target_code = """\ -function generateReport(data) { +export function generateReport(data) { const date = formatDate(new Date(data.timestamp)); const revenue = formatCurrency(data.revenue); const growth = formatPercentage(data.growth); @@ -699,21 +663,21 @@ def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project for helper in context.helper_functions: if helper.name == "formatDate": expected = """\ -function formatDate(date) { +export function formatDate(date) { return date.toISOString().split('T')[0]; } """ assert helper.source_code == expected elif helper.name == "formatCurrency": expected = """\ -function formatCurrency(amount) { +export function formatCurrency(amount) { return '$' + amount.toFixed(2); } """ assert helper.source_code == expected elif helper.name == "formatPercentage": expected = """\ -function formatPercentage(value) { +export function formatPercentage(value) { return (value * 100).toFixed(1) + '%'; } """ @@ -726,7 +690,7 @@ class TestClassMethodWithSiblingMethods: def test_graph_topological_sort(self, js_support, temp_project): """Test graph class with topological sort - similar to Python test_class_method_dependencies.""" code = """\ -class Graph { +export class Graph { constructor(vertices) { this.graph = new Map(); this.V = vertices; @@ -774,7 +738,7 @@ class Graph { context = js_support.extract_code_context(topo_sort, temp_project, temp_project) - # The extracted code should include class wrapper with constructor + # The extracted code should include class wrapper with constructor and sibling methods used expected_target_code = """\ class Graph { constructor(vertices) { @@ -794,6 +758,19 @@ class Graph { return stack; } + + topologicalSortUtil(v, visited, stack) { + visited[v] = true; + + const neighbors = this.graph.get(v) || []; + for (const i of neighbors) { + if (visited[i] === false) { + this.topologicalSortUtil(i, visited, stack); + } + } + + stack.unshift(v); + } } """ assert context.target_code == expected_target_code @@ -802,7 +779,7 @@ class Graph { def test_class_method_using_nested_helper_class(self, js_support, temp_project): """Test class method that uses another class as a helper - mirrors Python HelperClass test.""" code = """\ -class HelperClass { +export class HelperClass { constructor(name) { this.name = name; } @@ -816,7 +793,7 @@ class HelperClass { } } -class NestedHelper { +export class NestedHelper { constructor(name) { this.name = name; } @@ -826,11 +803,11 @@ class NestedHelper { } } -function mainMethod() { +export function mainMethod() { return 'hello'; } -class MainClass { +export class MainClass { constructor(name) { this.name = name; } @@ -890,7 +867,7 @@ def test_helper_from_another_file_commonjs(self, js_support, temp_project): main_code = """\ const { sorter } = require('./bubble_sort_with_math'); -function sortFromAnotherFile(arr) { +export function sortFromAnotherFile(arr) { const sortedArr = sorter(arr); return sortedArr; } @@ -906,7 +883,7 @@ def test_helper_from_another_file_commonjs(self, js_support, temp_project): context = js_support.extract_code_context(main_func, temp_project, temp_project) expected_target_code = """\ -function sortFromAnotherFile(arr) { +export function sortFromAnotherFile(arr) { const sortedArr = sorter(arr); return sortedArr; } @@ -943,12 +920,10 @@ def test_helper_from_another_file_esm(self, js_support, temp_project): main_code = """\ import identity, { double, triple } from './utils'; -function processNumber(n) { +export function processNumber(n) { const base = identity(n); return double(base) + triple(base); } - -export { processNumber }; """ main_path = temp_project / "main.js" main_path.write_text(main_code, encoding="utf-8") @@ -959,7 +934,7 @@ def test_helper_from_another_file_esm(self, js_support, temp_project): context = js_support.extract_code_context(process_func, temp_project, temp_project) expected_target_code = """\ -function processNumber(n) { +export function processNumber(n) { const base = identity(n); return double(base) + triple(base); } @@ -1007,7 +982,7 @@ def test_chained_imports_across_three_files(self, js_support, temp_project): main_code = """\ import { transformInput } from './middleware'; -function handleUserInput(rawInput) { +export function handleUserInput(rawInput) { try { const result = transformInput(rawInput); return { success: true, data: result }; @@ -1015,8 +990,6 @@ def test_chained_imports_across_three_files(self, js_support, temp_project): return { success: false, error: error.message }; } } - -export { handleUserInput }; """ main_path = temp_project / "main.js" main_path.write_text(main_code, encoding="utf-8") @@ -1027,7 +1000,7 @@ def test_chained_imports_across_three_files(self, js_support, temp_project): context = js_support.extract_code_context(handle_func, temp_project, temp_project) expected_target_code = """\ -function handleUserInput(rawInput) { +export function handleUserInput(rawInput) { try { const result = transformInput(rawInput); return { success: true, data: result }; @@ -1059,7 +1032,7 @@ def test_function_with_complex_generic_types(self, ts_support, temp_project): type Entity = T & Identifiable & Timestamped; -function createEntity(data: T): Entity { +export function createEntity(data: T): Entity { const now = new Date(); return { ...data, @@ -1078,7 +1051,7 @@ def test_function_with_complex_generic_types(self, ts_support, temp_project): context = ts_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function createEntity(data: T): Entity { +export function createEntity(data: T): Entity { const now = new Date(); return { ...data, @@ -1117,7 +1090,7 @@ def test_class_with_private_fields_and_typed_methods(self, ts_support, temp_proj maxSize: number; } -class TypedCache { +export class TypedCache { private readonly cache: Map>; private readonly config: CacheConfig; @@ -1235,15 +1208,13 @@ def test_typescript_with_type_imports(self, ts_support, temp_project): const DEFAULT_ROLE: UserRole = 'user'; -function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { +export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { return { id: Math.random().toString(36).substring(2), name: input.name, email: input.email }; } - -export { createUser }; """ service_path = temp_project / "service.ts" service_path.write_text(service_code, encoding="utf-8") @@ -1254,7 +1225,7 @@ def test_typescript_with_type_imports(self, ts_support, temp_project): context = ts_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { +export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { return { id: Math.random().toString(36).substring(2), name: input.name, @@ -1294,7 +1265,7 @@ class TestRecursionAndCircularDependencies: def test_self_recursive_factorial(self, js_support, temp_project): """Test self-recursive function does not list itself as helper.""" code = """\ -function factorial(n) { +export function factorial(n) { if (n <= 1) return 1; return n * factorial(n - 1); } @@ -1308,7 +1279,7 @@ def test_self_recursive_factorial(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function factorial(n) { +export function factorial(n) { if (n <= 1) return 1; return n * factorial(n - 1); } @@ -1319,12 +1290,12 @@ def test_self_recursive_factorial(self, js_support, temp_project): def test_mutually_recursive_even_odd(self, js_support, temp_project): """Test mutually recursive functions.""" code = """\ -function isEven(n) { +export function isEven(n) { if (n === 0) return true; return isOdd(n - 1); } -function isOdd(n) { +export function isOdd(n) { if (n === 0) return false; return isEven(n - 1); } @@ -1338,7 +1309,7 @@ def test_mutually_recursive_even_odd(self, js_support, temp_project): context = js_support.extract_code_context(is_even, temp_project, temp_project) expected_target_code = """\ -function isEven(n) { +export function isEven(n) { if (n === 0) return true; return isOdd(n - 1); } @@ -1351,7 +1322,7 @@ def test_mutually_recursive_even_odd(self, js_support, temp_project): # Verify helper source assert context.helper_functions[0].source_code == """\ -function isOdd(n) { +export function isOdd(n) { if (n === 0) return false; return isEven(n - 1); } @@ -1360,28 +1331,28 @@ def test_mutually_recursive_even_odd(self, js_support, temp_project): def test_complex_recursive_tree_traversal(self, js_support, temp_project): """Test complex recursive tree traversal with multiple recursive calls.""" code = """\ -function traversePreOrder(node, visit) { +export function traversePreOrder(node, visit) { if (!node) return; visit(node.value); traversePreOrder(node.left, visit); traversePreOrder(node.right, visit); } -function traverseInOrder(node, visit) { +export function traverseInOrder(node, visit) { if (!node) return; traverseInOrder(node.left, visit); visit(node.value); traverseInOrder(node.right, visit); } -function traversePostOrder(node, visit) { +export function traversePostOrder(node, visit) { if (!node) return; traversePostOrder(node.left, visit); traversePostOrder(node.right, visit); visit(node.value); } -function collectAllValues(root) { +export function collectAllValues(root) { const values = { pre: [], in: [], post: [] }; traversePreOrder(root, v => values.pre.push(v)); @@ -1400,7 +1371,7 @@ def test_complex_recursive_tree_traversal(self, js_support, temp_project): context = js_support.extract_code_context(collect_func, temp_project, temp_project) expected_target_code = """\ -function collectAllValues(root) { +export function collectAllValues(root) { const values = { pre: [], in: [], post: [] }; traversePreOrder(root, v => values.pre.push(v)); @@ -1423,7 +1394,7 @@ class TestAsyncPatternsAndPromises: def test_async_function_chain(self, js_support, temp_project): """Test async function that calls other async functions.""" code = """\ -async function fetchUserById(id) { +export async function fetchUserById(id) { const response = await fetch(`/api/users/${id}`); if (!response.ok) { throw new Error(`User ${id} not found`); @@ -1431,17 +1402,17 @@ def test_async_function_chain(self, js_support, temp_project): return response.json(); } -async function fetchUserPosts(userId) { +export async function fetchUserPosts(userId) { const response = await fetch(`/api/users/${userId}/posts`); return response.json(); } -async function fetchUserComments(userId) { +export async function fetchUserComments(userId) { const response = await fetch(`/api/users/${userId}/comments`); return response.json(); } -async function fetchUserProfile(userId) { +export async function fetchUserProfile(userId) { const user = await fetchUserById(userId); const [posts, comments] = await Promise.all([ fetchUserPosts(userId), @@ -1465,7 +1436,7 @@ def test_async_function_chain(self, js_support, temp_project): context = js_support.extract_code_context(profile_func, temp_project, temp_project) expected_target_code = """\ -async function fetchUserProfile(userId) { +export async function fetchUserProfile(userId) { const user = await fetchUserById(userId); const [posts, comments] = await Promise.all([ fetchUserPosts(userId), @@ -1493,7 +1464,7 @@ class TestExtractionReplacementRoundTrip: def test_extract_and_replace_class_method(self, js_support, temp_project): """Test extracting code context and then replacing the method.""" original_source = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1536,7 +1507,7 @@ class Counter { # Step 2: Simulate AI returning optimized code optimized_code_from_ai = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1551,7 +1522,7 @@ class Counter { result = js_support.replace_function(original_source, increment_func, optimized_code_from_ai) expected_result = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1578,7 +1549,7 @@ class TestEdgeCases: def test_function_with_complex_destructuring(self, js_support, temp_project): """Test function with complex nested destructuring parameters.""" code = """\ -function processApiResponse({ +export function processApiResponse({ data: { users = [], meta: { total, page } = {} } = {}, status, headers: { 'content-type': contentType } = {} @@ -1600,7 +1571,7 @@ def test_function_with_complex_destructuring(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function processApiResponse({ +export function processApiResponse({ data: { users = [], meta: { total, page } = {} } = {}, status, headers: { 'content-type': contentType } = {} @@ -1619,13 +1590,13 @@ def test_function_with_complex_destructuring(self, js_support, temp_project): def test_generator_function(self, js_support, temp_project): """Test generator function extraction.""" code = """\ -function* range(start, end, step = 1) { +export function* range(start, end, step = 1) { for (let i = start; i < end; i += step) { yield i; } } -function* fibonacci(limit) { +export function* fibonacci(limit) { let [a, b] = [0, 1]; while (a < limit) { yield a; @@ -1642,7 +1613,7 @@ def test_generator_function(self, js_support, temp_project): context = js_support.extract_code_context(range_func, temp_project, temp_project) expected_target_code = """\ -function* range(start, end, step = 1) { +export function* range(start, end, step = 1) { for (let i = start; i < end; i += step) { yield i; } @@ -1660,7 +1631,7 @@ def test_function_with_computed_property_names(self, js_support, temp_project): AGE: 'user_age' }; -function createUserObject(name, email, age) { +export function createUserObject(name, email, age) { return { [FIELD_KEYS.NAME]: name, [FIELD_KEYS.EMAIL]: email, @@ -1677,7 +1648,7 @@ def test_function_with_computed_property_names(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function createUserObject(name, email, age) { +export function createUserObject(name, email, age) { return { [FIELD_KEYS.NAME]: name, [FIELD_KEYS.EMAIL]: email, @@ -1937,7 +1908,7 @@ class TestContextProperties: def test_javascript_context_has_correct_language(self, js_support, temp_project): """Test that JavaScript context has correct language property.""" code = """\ -function test() { +export function test() { return 1; } """ @@ -1956,7 +1927,7 @@ def test_javascript_context_has_correct_language(self, js_support, temp_project) def test_typescript_context_has_javascript_language(self, ts_support, temp_project): """Test that TypeScript context uses JavaScript language enum.""" code = """\ -function test(): number { +export function test(): number { return 1; } """ @@ -1977,7 +1948,7 @@ class TestContextValidation: def test_all_class_methods_produce_valid_syntax(self, js_support, temp_project): """Test that all extracted class methods are syntactically valid JavaScript.""" code = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } diff --git a/tests/test_languages/test_function_discovery_integration.py b/tests/test_languages/test_function_discovery_integration.py index 621a00d79..c91f91fe5 100644 --- a/tests/test_languages/test_function_discovery_integration.py +++ b/tests/test_languages/test_function_discovery_integration.py @@ -89,11 +89,11 @@ def test_javascript_file_routes_to_js_handler(self): """Test that JavaScript files use the JavaScript handler.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function add(a, b) { +export function add(a, b) { return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """) @@ -124,7 +124,7 @@ def test_function_to_optimize_has_correct_fields(self): """Test that FunctionToOptimize has all required fields populated.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -162,7 +162,7 @@ def add(a, b): def test_discovers_javascript_files_when_specified(self, tmp_path): """Test that JavaScript files are discovered when language is specified.""" (tmp_path / "module.js").write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -177,7 +177,7 @@ def py_func(): return 1 """) (tmp_path / "js_module.js").write_text(""" -function jsFunc() { +export function jsFunc() { return 1; } """) diff --git a/tests/test_languages/test_javascript_e2e.py b/tests/test_languages/test_javascript_e2e.py index 2fe25c18a..ae268def5 100644 --- a/tests/test_languages/test_javascript_e2e.py +++ b/tests/test_languages/test_javascript_e2e.py @@ -129,13 +129,7 @@ def test_extract_code_context_for_javascript(self, js_project_dir): assert len(context.read_writable_code.code_strings) > 0 code = context.read_writable_code.code_strings[0].code - expected_code = """/** - * Calculate the nth Fibonacci number using naive recursion. - * This is intentionally slow to demonstrate optimization potential. - * @param {number} n - The index of the Fibonacci number to calculate - * @returns {number} - The nth Fibonacci number - */ -function fibonacci(n) { + expected_code = """export function fibonacci(n) { if (n <= 1) { return n; } @@ -155,16 +149,16 @@ def test_replace_function_in_javascript_file(self): from codeflash.languages.base import FunctionInfo original_source = """ -function add(a, b) { +export function add(a, b) { return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """ - new_function = """function add(a, b) { + new_function = """export function add(a, b) { // Optimized version return a + b; }""" @@ -178,12 +172,12 @@ def test_replace_function_in_javascript_file(self): result = js_support.replace_function(original_source, func_info, new_function) expected_result = """ -function add(a, b) { +export function add(a, b) { // Optimized version return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """ @@ -234,7 +228,7 @@ def test_function_to_optimize_has_correct_fields(self): with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -244,7 +238,7 @@ class Calculator { } } -function standalone(x) { +export function standalone(x) { return x * 2; } """) diff --git a/tests/test_languages/test_javascript_instrumentation.py b/tests/test_languages/test_javascript_instrumentation.py index ba25a3af5..27662bd59 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -663,4 +663,197 @@ def test_this_method_call_exact_output(self): expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);" assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}" - assert counter == 1 \ No newline at end of file + assert counter == 1 + + +class TestFixImportsInsideTestBlocks: + """Tests for fix_imports_inside_test_blocks function.""" + + def test_fix_named_import_inside_test_block(self): + """Test fixing named import inside test function.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +test('should work', () => { + const mock = jest.fn(); + import { foo } from '../src/module'; + expect(foo()).toBe(true); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + assert "const { foo } = require('../src/module');" in fixed + assert "import { foo }" not in fixed + + def test_fix_default_import_inside_test_block(self): + """Test fixing default import inside test function.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +test('should work', () => { + env.isTest.mockReturnValue(false); + import queuesModule from '../src/queue/queue'; + expect(queuesModule).toBeDefined(); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + assert "const queuesModule = require('../src/queue/queue');" in fixed + assert "import queuesModule from" not in fixed + + def test_fix_namespace_import_inside_test_block(self): + """Test fixing namespace import inside test function.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +test('should work', () => { + import * as utils from '../src/utils'; + expect(utils.foo()).toBe(true); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + assert "const utils = require('../src/utils');" in fixed + assert "import * as utils" not in fixed + + def test_preserve_top_level_imports(self): + """Test that top-level imports are not modified.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +import { jest, describe, test, expect } from '@jest/globals'; +import { foo } from '../src/module'; + +describe('test suite', () => { + test('should work', () => { + expect(foo()).toBe(true); + }); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + # Top-level imports should remain unchanged + assert "import { jest, describe, test, expect } from '@jest/globals';" in fixed + assert "import { foo } from '../src/module';" in fixed + + def test_empty_code(self): + """Test handling empty code.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + assert fix_imports_inside_test_blocks("") == "" + assert fix_imports_inside_test_blocks(" ") == " " + + +class TestFixJestMockPaths: + """Tests for fix_jest_mock_paths function.""" + + def test_fix_mock_path_when_source_relative(self): + """Test fixing mock path that's relative to source file.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + # Create directory structure + src_dir = Path(tmpdir) / "src" / "queue" + tests_dir = Path(tmpdir) / "tests" + env_file = Path(tmpdir) / "src" / "environment.ts" + + src_dir.mkdir(parents=True) + tests_dir.mkdir(parents=True) + env_file.parent.mkdir(parents=True, exist_ok=True) + env_file.write_text("export const env = {};") + + source_file = src_dir / "queue.ts" + source_file.write_text("import env from '../environment';") + + test_file = tests_dir / "test_queue.test.ts" + + # Test code with incorrect mock path (relative to source, not test) + test_code = """ +import { jest, describe, test, expect } from '@jest/globals'; +jest.mock('../environment'); +jest.mock('../redis/utils'); + +describe('queue', () => { + test('works', () => {}); +}); +""" + fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir) + + # Should fix the path to be relative to the test file + assert "jest.mock('../src/environment')" in fixed + + def test_preserve_valid_mock_path(self): + """Test that valid mock paths are not modified.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + # Create directory structure + src_dir = Path(tmpdir) / "src" + tests_dir = Path(tmpdir) / "tests" + + src_dir.mkdir(parents=True) + tests_dir.mkdir(parents=True) + + # Create the file being mocked at the correct location + mock_file = src_dir / "utils.ts" + mock_file.write_text("export const utils = {};") + + source_file = src_dir / "main.ts" + source_file.write_text("") + test_file = tests_dir / "test_main.test.ts" + + # Test code with correct mock path (valid from test location) + test_code = """ +jest.mock('../src/utils'); + +describe('main', () => { + test('works', () => {}); +}); +""" + fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir) + + # Should keep the path unchanged since it's valid + assert "jest.mock('../src/utils')" in fixed + + def test_fix_doMock_path(self): + """Test fixing jest.doMock path.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + # Create directory structure: src/queue/queue.ts imports ../environment (-> src/environment.ts) + src_dir = Path(tmpdir) / "src" + queue_dir = src_dir / "queue" + tests_dir = Path(tmpdir) / "tests" + env_file = src_dir / "environment.ts" + + queue_dir.mkdir(parents=True) + tests_dir.mkdir(parents=True) + env_file.write_text("export const env = {};") + + source_file = queue_dir / "queue.ts" + source_file.write_text("") + test_file = tests_dir / "test_queue.test.ts" + + # From src/queue/queue.ts, ../environment resolves to src/environment.ts + # Test file is at tests/test_queue.test.ts + # So the correct mock path from test should be ../src/environment + test_code = """ +jest.doMock('../environment', () => ({ isTest: jest.fn() })); +""" + fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir) + + # Should fix the doMock path + assert "jest.doMock('../src/environment'" in fixed + + def test_empty_code(self): + """Test handling empty code.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + tests_dir = Path(tmpdir) / "tests" + tests_dir.mkdir() + source_file = Path(tmpdir) / "src" / "main.ts" + test_file = tests_dir / "test.ts" + + assert fix_jest_mock_paths("", test_file, source_file, tests_dir) == "" + assert fix_jest_mock_paths(" ", test_file, source_file, tests_dir) == " " \ No newline at end of file diff --git a/tests/test_languages/test_javascript_optimization_flow.py b/tests/test_languages/test_javascript_optimization_flow.py index 7c7ba5aa6..26d2db140 100644 --- a/tests/test_languages/test_javascript_optimization_flow.py +++ b/tests/test_languages/test_javascript_optimization_flow.py @@ -60,6 +60,7 @@ def test_function_to_optimize_has_correct_language_for_javascript(self, tmp_path function add(a, b) { return a + b; } +module.exports = { add }; """) functions = find_all_functions_in_file(js_file) diff --git a/tests/test_languages/test_javascript_support.py b/tests/test_languages/test_javascript_support.py index fc7343e48..bdc4a4be5 100644 --- a/tests/test_languages/test_javascript_support.py +++ b/tests/test_languages/test_javascript_support.py @@ -46,7 +46,7 @@ def test_discover_simple_function(self, js_support): """Test discovering a simple function declaration.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -62,15 +62,15 @@ def test_discover_multiple_functions(self, js_support): """Test discovering multiple functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function add(a, b) { +export function add(a, b) { return a + b; } -function subtract(a, b) { +export function subtract(a, b) { return a - b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """) @@ -86,11 +86,11 @@ def test_discover_arrow_function(self, js_support): """Test discovering arrow functions assigned to variables.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -const add = (a, b) => { +export const add = (a, b) => { return a + b; }; -const multiply = (x, y) => x * y; +export const multiply = (x, y) => x * y; """) f.flush() @@ -104,11 +104,11 @@ def test_discover_function_without_return_excluded(self, js_support): """Test that functions without return are excluded by default.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function withReturn() { +export function withReturn() { return 1; } -function withoutReturn() { +export function withoutReturn() { console.log("hello"); } """) @@ -124,7 +124,7 @@ def test_discover_class_methods(self, js_support): """Test discovering class methods.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -147,11 +147,11 @@ def test_discover_async_functions(self, js_support): """Test discovering async functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url); } -function syncFunction() { +export function syncFunction() { return 1; } """) @@ -171,11 +171,11 @@ def test_discover_with_filter_exclude_async(self, js_support): """Test filtering out async functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -async function asyncFunc() { +export async function asyncFunc() { return 1; } -function syncFunc() { +export function syncFunc() { return 2; } """) @@ -191,11 +191,11 @@ def test_discover_with_filter_exclude_methods(self, js_support): """Test filtering out class methods.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function standalone() { +export function standalone() { return 1; } -class MyClass { +export class MyClass { method() { return 2; } @@ -212,11 +212,11 @@ class MyClass { def test_discover_line_numbers(self, js_support): """Test that line numbers are correctly captured.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""function func1() { + f.write("""export function func1() { return 1; } -function func2() { +export function func2() { const x = 1; const y = 2; return x + y; @@ -238,7 +238,7 @@ def test_discover_generator_function(self, js_support): """Test discovering generator functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function* numberGenerator() { +export function* numberGenerator() { yield 1; yield 2; return 3; @@ -271,7 +271,7 @@ def test_discover_function_expression(self, js_support): """Test discovering function expressions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -const add = function(a, b) { +export const add = function(a, b) { return a + b; }; """) @@ -290,7 +290,7 @@ def test_discover_immediately_invoked_function_excluded(self, js_support): return 1; })(); -function named() { +export function named() { return 2; } """) @@ -476,7 +476,7 @@ class TestExtractCodeContext: def test_extract_simple_function(self, js_support): """Test extracting context for a simple function.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""function add(a, b) { + f.write("""export function add(a, b) { return a + b; } """) @@ -495,11 +495,11 @@ def test_extract_simple_function(self, js_support): def test_extract_with_helper(self, js_support): """Test extracting context with helper functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""function helper(x) { + f.write("""export function helper(x) { return x * 2; } -function main(a) { +export function main(a) { return helper(a) + 1; } """) @@ -523,7 +523,7 @@ class TestIntegration: def test_discover_and_replace_workflow(self, js_support): """Test full discover -> replace workflow.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - original_code = """function fibonacci(n) { + original_code = """export function fibonacci(n) { if (n <= 1) { return n; } @@ -541,7 +541,7 @@ def test_discover_and_replace_workflow(self, js_support): assert func.function_name == "fibonacci" # Replace - optimized_code = """function fibonacci(n) { + optimized_code = """export function fibonacci(n) { // Memoized version const memo = {0: 0, 1: 1}; for (let i = 2; i <= n; i++) { @@ -561,7 +561,7 @@ def test_multiple_classes_and_functions(self, js_support): """Test discovering and working with complex file.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -571,13 +571,13 @@ class Calculator { } } -class StringUtils { +export class StringUtils { reverse(s) { return s.split('').reverse().join(''); } } -function standalone() { +export function standalone() { return 42; } """) @@ -605,11 +605,11 @@ def test_jsx_file(self, js_support): f.write(""" import React from 'react'; -function Button({ onClick, children }) { +export function Button({ onClick, children }) { return ; } -const Card = ({ title, content }) => { +export const Card = ({ title, content }) => { return (

{title}

@@ -673,7 +673,7 @@ class TestClassMethodExtraction: def test_extract_class_method_wraps_in_class(self, js_support): """Test that extracting a class method wraps it in a class definition.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Calculator { + f.write("""export class Calculator { add(a, b) { return a + b; } @@ -694,6 +694,7 @@ def test_extract_class_method_wraps_in_class(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check for exact extraction output + # Note: export keyword is not included in extracted class wrapper expected_code = """class Calculator { add(a, b) { return a + b; @@ -709,7 +710,7 @@ def test_extract_class_method_with_jsdoc(self, js_support): f.write("""/** * A simple calculator class. */ -class Calculator { +export class Calculator { /** * Adds two numbers. * @param {number} a - First number @@ -730,10 +731,9 @@ class Calculator { context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check - includes class JSDoc, class definition, method JSDoc, and method - expected_code = """/** - * A simple calculator class. - */ -class Calculator { + # Note: export keyword is not included in extracted class wrapper + # Note: Class-level JSDoc is not included when extracting a method + expected_code = """class Calculator { /** * Adds two numbers. * @param {number} a - First number @@ -751,7 +751,7 @@ class Calculator { def test_extract_class_method_syntax_valid(self, js_support): """Test that extracted class method code is always syntactically valid.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class FibonacciCalculator { + f.write("""export class FibonacciCalculator { fibonacci(n) { if (n <= 1) { return n; @@ -769,6 +769,7 @@ def test_extract_class_method_syntax_valid(self, js_support): context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class FibonacciCalculator { fibonacci(n) { if (n <= 1) { @@ -784,7 +785,7 @@ def test_extract_class_method_syntax_valid(self, js_support): def test_extract_nested_class_method(self, js_support): """Test extracting a method from a nested class structure.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Outer { + f.write("""export class Outer { createInner() { return class Inner { getValue() { @@ -808,6 +809,7 @@ def test_extract_nested_class_method(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class Outer { add(a, b) { return a + b; @@ -820,7 +822,7 @@ def test_extract_nested_class_method(self, js_support): def test_extract_async_class_method(self, js_support): """Test extracting an async class method.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class ApiClient { + f.write("""export class ApiClient { async fetchData(url) { const response = await fetch(url); return response.json(); @@ -836,6 +838,7 @@ def test_extract_async_class_method(self, js_support): context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class ApiClient { async fetchData(url) { const response = await fetch(url); @@ -849,7 +852,7 @@ def test_extract_async_class_method(self, js_support): def test_extract_static_class_method(self, js_support): """Test extracting a static class method.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class MathUtils { + f.write("""export class MathUtils { static add(a, b) { return a + b; } @@ -869,6 +872,7 @@ def test_extract_static_class_method(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class MathUtils { static add(a, b) { return a + b; @@ -881,7 +885,7 @@ def test_extract_static_class_method(self, js_support): def test_extract_class_method_without_class_jsdoc(self, js_support): """Test extracting a method from a class without JSDoc.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class SimpleClass { + f.write("""export class SimpleClass { simpleMethod() { return "hello"; } @@ -896,6 +900,7 @@ def test_extract_class_method_without_class_jsdoc(self, js_support): context = js_support.extract_code_context(method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class SimpleClass { simpleMethod() { return "hello"; @@ -1061,7 +1066,7 @@ class TestClassMethodEdgeCases: def test_class_with_constructor(self, js_support): """Test handling classes with constructors.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Counter { + f.write("""export class Counter { constructor(start = 0) { this.value = start; } @@ -1083,7 +1088,7 @@ def test_class_with_constructor(self, js_support): def test_class_with_getters_setters(self, js_support): """Test handling classes with getters and setters.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Person { + f.write("""export class Person { constructor(name) { this._name = name; } @@ -1113,13 +1118,13 @@ def test_class_with_getters_setters(self, js_support): def test_class_extending_another(self, js_support): """Test handling classes that extend another class.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Animal { + f.write("""export class Animal { speak() { return 'sound'; } } -class Dog extends Animal { +export class Dog extends Animal { speak() { return 'bark'; } @@ -1141,6 +1146,7 @@ class Dog extends Animal { context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class Dog { fetch() { return 'ball'; @@ -1153,7 +1159,7 @@ class Dog extends Animal { def test_class_with_private_method(self, js_support): """Test handling classes with private methods (ES2022+).""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class SecureClass { + f.write("""export class SecureClass { #privateMethod() { return 'secret'; } @@ -1175,7 +1181,7 @@ def test_class_with_private_method(self, js_support): def test_commonjs_class_export(self, js_support): """Test handling CommonJS exported classes.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Calculator { + f.write("""export class Calculator { add(a, b) { return a + b; } @@ -1236,7 +1242,7 @@ def test_extract_context_then_replace_method(self, js_support): 3. Replace extracts just the method body and replaces in original """ original_source = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1303,7 +1309,7 @@ class Counter { # Verify result with exact string equality expected_result = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1333,7 +1339,7 @@ def test_typescript_extract_context_then_replace_method(self): ts_support = TypeScriptSupport() original_source = """\ -class User { +export class User { private name: string; private age: number; @@ -1350,8 +1356,6 @@ class User { return this.age; } } - -export { User }; """ with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: f.write(original_source) @@ -1408,7 +1412,7 @@ class User { # Verify result with exact string equality expected_result = """\ -class User { +export class User { private name: string; private age: number; @@ -1426,8 +1430,6 @@ class User { return this.age; } } - -export { User }; """ assert result == expected_result, ( f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}" @@ -1437,7 +1439,7 @@ class User { def test_extract_replace_preserves_other_methods(self, js_support): """Test that replacing one method doesn't affect others.""" original_source = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -1499,7 +1501,7 @@ class Calculator { # Verify result with exact string equality expected_result = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -1525,7 +1527,7 @@ class Calculator { def test_extract_static_method_then_replace(self, js_support): """Test extracting and replacing a static method.""" original_source = """\ -class MathUtils { +export class MathUtils { constructor() { this.cache = {}; } @@ -1538,8 +1540,6 @@ class MathUtils { return a * b; } } - -module.exports = { MathUtils }; """ with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(original_source) @@ -1586,7 +1586,7 @@ class MathUtils { # Verify result with exact string equality expected_result = """\ -class MathUtils { +export class MathUtils { constructor() { this.cache = {}; } @@ -1600,8 +1600,6 @@ class MathUtils { return a * b; } } - -module.exports = { MathUtils }; """ assert result == expected_result, ( f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}" diff --git a/tests/test_languages/test_javascript_test_discovery.py b/tests/test_languages/test_javascript_test_discovery.py index 9166b589e..9126d1805 100644 --- a/tests/test_languages/test_javascript_test_discovery.py +++ b/tests/test_languages/test_javascript_test_discovery.py @@ -29,7 +29,7 @@ def test_discover_tests_basic(self, js_support): # Create source file source_file = tmpdir / "math.js" source_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } @@ -71,7 +71,7 @@ def test_discover_tests_spec_suffix(self, js_support): # Create source file source_file = tmpdir / "calculator.js" source_file.write_text(""" -function multiply(a, b) { +export function multiply(a, b) { return a * b; } @@ -103,7 +103,7 @@ def test_discover_tests_in_tests_directory(self, js_support): # Create source file source_file = tmpdir / "utils.js" source_file.write_text(""" -function formatDate(date) { +export function formatDate(date) { return date.toISOString(); } @@ -136,11 +136,11 @@ def test_discover_tests_nested_describe(self, js_support): source_file = tmpdir / "string_utils.js" source_file.write_text(""" -function capitalize(str) { +export function capitalize(str) { return str.charAt(0).toUpperCase() + str.slice(1); } -function lowercase(str) { +export function lowercase(str) { return str.toLowerCase(); } @@ -186,7 +186,7 @@ def test_discover_tests_with_it_block(self, js_support): source_file = tmpdir / "array_utils.js" source_file.write_text(""" -function sum(arr) { +export function sum(arr) { return arr.reduce((a, b) => a + b, 0); } @@ -254,7 +254,7 @@ def test_discover_tests_default_export(self, js_support): source_file = tmpdir / "greeter.js" source_file.write_text(""" -function greet(name) { +export function greet(name) { return `Hello, ${name}!`; } @@ -282,7 +282,7 @@ def test_discover_tests_class_methods(self, js_support): source_file = tmpdir / "calculator_class.js" source_file.write_text(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -333,7 +333,7 @@ def test_discover_tests_multi_level_directories(self, js_support): source_file = src_dir / "helpers.js" source_file.write_text(""" -function clamp(value, min, max) { +export function clamp(value, min, max) { return Math.min(Math.max(value, min), max); } @@ -375,11 +375,11 @@ def test_discover_tests_async_functions(self, js_support): source_file = tmpdir / "async_utils.js" source_file.write_text(""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url).then(r => r.json()); } -async function delay(ms) { +export async function delay(ms) { return new Promise(resolve => setTimeout(resolve, ms)); } @@ -413,7 +413,7 @@ def test_discover_tests_jsx_component(self, js_support): source_file.write_text(""" import React from 'react'; -function Button({ onClick, children }) { +export function Button({ onClick, children }) { return ; } @@ -449,7 +449,7 @@ def test_discover_tests_no_matching_tests(self, js_support): source_file = tmpdir / "untested.js" source_file.write_text(""" -function untestedFunction() { +export function untestedFunction() { return 42; } @@ -479,11 +479,11 @@ def test_discover_tests_function_name_in_source(self, js_support): source_file = tmpdir / "validators.js" source_file.write_text(""" -function isEmail(str) { +export function isEmail(str) { return str.includes('@'); } -function isUrl(str) { +export function isUrl(str) { return str.startsWith('http'); } @@ -515,11 +515,11 @@ def test_discover_tests_multiple_test_files(self, js_support): source_file = tmpdir / "shared_utils.js" source_file.write_text(""" -function helper1() { +export function helper1() { return 1; } -function helper2() { +export function helper2() { return 2; } @@ -558,7 +558,7 @@ def test_discover_tests_template_literal_names(self, js_support): source_file = tmpdir / "format.js" source_file.write_text(""" -function formatNumber(n) { +export function formatNumber(n) { return n.toFixed(2); } @@ -587,7 +587,7 @@ def test_discover_tests_aliased_import(self, js_support): source_file = tmpdir / "transform.js" source_file.write_text(""" -function transformData(data) { +export function transformData(data) { return data.map(x => x * 2); } @@ -792,8 +792,8 @@ def test_require_named_import(self, js_support): source_file = tmpdir / "funcs.js" source_file.write_text(""" -function funcA() { return 1; } -function funcB() { return 2; } +export function funcA() { return 1; } +export function funcB() { return 2; } module.exports = { funcA, funcB }; """) @@ -846,7 +846,7 @@ def test_default_import(self, js_support): source_file = tmpdir / "default_export.js" source_file.write_text(""" -function mainFunc() { return 'main'; } +export function mainFunc() { return 'main'; } module.exports = mainFunc; """) @@ -875,7 +875,7 @@ def test_comments_in_test_file(self, js_support): source_file = tmpdir / "commented.js" source_file.write_text(""" -function compute() { return 42; } +export function compute() { return 42; } module.exports = { compute }; """) @@ -908,7 +908,7 @@ def test_test_file_with_syntax_error(self, js_support): source_file = tmpdir / "valid.js" source_file.write_text(""" -function validFunc() { return 1; } +export function validFunc() { return 1; } module.exports = { validFunc }; """) @@ -933,8 +933,8 @@ def test_function_with_same_name_as_jest_api(self, js_support): source_file = tmpdir / "conflict.js" source_file.write_text(""" -function test(value) { return value > 0; } -function describe(obj) { return JSON.stringify(obj); } +export function test(value) { return value > 0; } +export function describe(obj) { return JSON.stringify(obj); } module.exports = { test, describe }; """) @@ -962,7 +962,7 @@ def test_empty_test_directory(self, js_support): source_file = tmpdir / "lonely.js" source_file.write_text(""" -function lonelyFunc() { return 'alone'; } +export function lonelyFunc() { return 'alone'; } module.exports = { lonelyFunc }; """) @@ -980,14 +980,14 @@ def test_circular_imports(self, js_support): file_a = tmpdir / "moduleA.js" file_a.write_text(""" const { funcB } = require('./moduleB'); -function funcA() { return 'A' + (funcB ? funcB() : ''); } +export function funcA() { return 'A' + (funcB ? funcB() : ''); } module.exports = { funcA }; """) file_b = tmpdir / "moduleB.js" file_b.write_text(""" const { funcA } = require('./moduleA'); -function funcB() { return 'B'; } +export function funcB() { return 'B'; } module.exports = { funcB }; """) @@ -1126,17 +1126,17 @@ def test_full_discovery_workflow(self, js_support): # Source file source_file = src_dir / "utils.js" source_file.write_text(r""" -function validateEmail(email) { +export function validateEmail(email) { const re = /^[^\s@]+@[^\s@]+\.[^\s@]+$/; return re.test(email); } -function validatePhone(phone) { +export function validatePhone(phone) { const re = /^\d{10}$/; return re.test(phone); } -function formatName(first, last) { +export function formatName(first, last) { return `${first} ${last}`.trim(); } @@ -1197,7 +1197,7 @@ def test_discovery_with_fixtures(self, js_support): source_file = tmpdir / "database.js" source_file.write_text(""" -class Database { +export class Database { constructor() { this.data = []; } @@ -1259,13 +1259,13 @@ def test_test_file_imports_different_module(self, js_support): # Create two source files source_a = tmpdir / "moduleA.js" source_a.write_text(""" -function funcA() { return 'A'; } +export function funcA() { return 'A'; } module.exports = { funcA }; """) source_b = tmpdir / "moduleB.js" source_b.write_text(""" -function funcB() { return 'B'; } +export function funcB() { return 'B'; } module.exports = { funcB }; """) @@ -1296,9 +1296,9 @@ def test_test_file_imports_only_specific_function(self, js_support): source_file = tmpdir / "utils.js" source_file.write_text(""" -function funcOne() { return 1; } -function funcTwo() { return 2; } -function funcThree() { return 3; } +export function funcOne() { return 1; } +export function funcTwo() { return 2; } +export function funcThree() { return 3; } module.exports = { funcOne, funcTwo, funcThree }; """) @@ -1325,7 +1325,7 @@ def test_function_name_as_string_not_import(self, js_support): source_file = tmpdir / "target.js" source_file.write_text(""" -function targetFunc() { return 'target'; } +export function targetFunc() { return 'target'; } module.exports = { targetFunc }; """) @@ -1354,7 +1354,7 @@ def test_module_import_with_method_access(self, js_support): source_file = tmpdir / "math.js" source_file.write_text(""" -function calculate(x) { return x * 2; } +export function calculate(x) { return x * 2; } module.exports = { calculate }; """) @@ -1380,7 +1380,7 @@ def test_class_method_discovery_via_class_import(self, js_support): source_file = tmpdir / "myclass.js" source_file.write_text(""" -class MyClass { +export class MyClass { methodA() { return 'A'; } methodB() { return 'B'; } } @@ -1416,7 +1416,7 @@ def test_nested_module_structure(self, js_support): source_file = src_dir / "helpers.js" source_file.write_text(""" -function deepHelper() { return 'deep'; } +export function deepHelper() { return 'deep'; } module.exports = { deepHelper }; """) @@ -1574,9 +1574,9 @@ def test_multiple_functions_same_file_different_tests(self, js_support): source_file = tmpdir / "multiple.js" source_file.write_text(""" -function addNumbers(a, b) { return a + b; } -function subtractNumbers(a, b) { return a - b; } -function multiplyNumbers(a, b) { return a * b; } +export function addNumbers(a, b) { return a + b; } +export function subtractNumbers(a, b) { return a - b; } +export function multiplyNumbers(a, b) { return a * b; } module.exports = { addNumbers, subtractNumbers, multiplyNumbers }; """) @@ -1613,7 +1613,7 @@ def test_test_in_wrong_describe_still_discovered(self, js_support): source_file = tmpdir / "funcs.js" source_file.write_text(""" -function targetFunc() { return 'target'; } +export function targetFunc() { return 'target'; } module.exports = { targetFunc }; """) @@ -1705,7 +1705,7 @@ def test_class_method_qualified_name(self, js_support): source_file = tmpdir / "calculator.js" source_file.write_text(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } subtract(a, b) { return a - b; } } @@ -1726,7 +1726,7 @@ def test_nested_class_method(self, js_support): source_file = tmpdir / "nested.js" source_file.write_text(""" -class Outer { +export class Outer { innerMethod() { class Inner { deepMethod() { return 'deep'; } diff --git a/tests/test_languages/test_js_code_extractor.py b/tests/test_languages/test_js_code_extractor.py index b1dcee81f..a21f15e2e 100644 --- a/tests/test_languages/test_js_code_extractor.py +++ b/tests/test_languages/test_js_code_extractor.py @@ -109,12 +109,7 @@ def test_extract_context_includes_direct_helpers(self, js_support, cjs_project): factorial_helper = helper_dict["factorial"] expected_factorial_code = """\ -/** - * Calculate factorial recursively. - * @param n - Non-negative integer - * @returns Factorial of n - */ -function factorial(n) { +export function factorial(n) { // Intentionally inefficient recursive implementation if (n <= 1) return 1; return n * factorial(n - 1); @@ -196,46 +191,22 @@ def test_extract_compound_interest_helpers(self, js_support, cjs_project): # STRICT: Verify each helper's code exactly expected_add_code = """\ -/** - * Add two numbers. - * @param a - First number - * @param b - Second number - * @returns Sum of a and b - */ -function add(a, b) { +export function add(a, b) { return a + b; }""" expected_multiply_code = """\ -/** - * Multiply two numbers. - * @param a - First number - * @param b - Second number - * @returns Product of a and b - */ -function multiply(a, b) { +export function multiply(a, b) { return a * b; }""" expected_format_number_code = """\ -/** - * Format a number to specified decimal places. - * @param num - Number to format - * @param decimals - Number of decimal places - * @returns Formatted number - */ -function formatNumber(num, decimals) { +export function formatNumber(num, decimals) { return Number(num.toFixed(decimals)); }""" expected_validate_input_code = """\ -/** - * Validate that input is a valid number. - * @param value - Value to validate - * @param name - Parameter name for error message - * @throws Error if value is not a valid number - */ -function validateInput(value, name) { +export function validateInput(value, name) { if (typeof value !== 'number' || isNaN(value)) { throw new Error(`Invalid ${name}: must be a number`); } @@ -317,13 +288,7 @@ class Calculator { assert set(helper_dict.keys()) == {"add"}, f"Expected 'add' helper, got: {list(helper_dict.keys())}" expected_add_code = """\ -/** - * Add two numbers. - * @param a - First number - * @param b - Second number - * @returns Sum of a and b - */ -function add(a, b) { +export function add(a, b) { return a + b; }""" @@ -702,7 +667,7 @@ def js_support(self): def test_standalone_function(self, js_support, tmp_path): """Test standalone function with no helpers.""" source = """\ -function standalone(x) { +export function standalone(x) { return x * 2; } @@ -718,7 +683,7 @@ def test_standalone_function(self, js_support, tmp_path): # STRICT: Exact code comparison expected_code = """\ -function standalone(x) { +export function standalone(x) { return x * 2; }""" assert context.target_code.strip() == expected_code.strip(), ( @@ -735,7 +700,7 @@ def test_external_package_excluded(self, js_support, tmp_path): source = """\ const _ = require('lodash'); -function processArray(arr) { +export function processArray(arr) { return _.map(arr, x => x * 2); } @@ -750,7 +715,7 @@ def test_external_package_excluded(self, js_support, tmp_path): context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) expected_code = """\ -function processArray(arr) { +export function processArray(arr) { return _.map(arr, x => x * 2); }""" @@ -769,7 +734,7 @@ def test_external_package_excluded(self, js_support, tmp_path): def test_recursive_function(self, js_support, tmp_path): """Test recursive function doesn't list itself as helper.""" source = """\ -function fibonacci(n) { +export function fibonacci(n) { if (n <= 1) return n; return fibonacci(n - 1) + fibonacci(n - 2); } @@ -786,7 +751,7 @@ def test_recursive_function(self, js_support, tmp_path): # STRICT: Exact code comparison expected_code = """\ -function fibonacci(n) { +export function fibonacci(n) { if (n <= 1) return n; return fibonacci(n - 1) + fibonacci(n - 2); }""" @@ -803,7 +768,7 @@ def test_arrow_function_helper(self, js_support, tmp_path): source = """\ const helper = (x) => x * 2; -const processValue = (value) => { +export const processValue = (value) => { return helper(value) + 1; }; @@ -818,7 +783,7 @@ def test_arrow_function_helper(self, js_support, tmp_path): context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) expected_code = """\ -const processValue = (value) => { +export const processValue = (value) => { return helper(value) + 1; };""" @@ -854,7 +819,7 @@ def ts_support(self): def test_method_extraction_includes_constructor(self, js_support, tmp_path): """Test that extracting a class method includes the constructor.""" source = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -894,7 +859,7 @@ class Counter { def test_method_extraction_class_without_constructor(self, js_support, tmp_path): """Test extracting a method from a class that has no constructor.""" source = """\ -class MathUtils { +export class MathUtils { add(a, b) { return a + b; } @@ -928,7 +893,7 @@ class MathUtils { def test_typescript_method_extraction_includes_fields(self, ts_support, tmp_path): """Test that TypeScript method extraction includes class fields.""" source = """\ -class User { +export class User { private name: string; public age: number; @@ -941,8 +906,6 @@ class User { return this.name; } } - -export { User }; """ test_file = tmp_path / "user.ts" test_file.write_text(source) @@ -974,7 +937,7 @@ class User { def test_typescript_fields_only_no_constructor(self, ts_support, tmp_path): """Test TypeScript class with fields but no constructor.""" source = """\ -class Config { +export class Config { readonly apiUrl: string = "https://api.example.com"; timeout: number = 5000; @@ -982,8 +945,6 @@ class Config { return this.apiUrl; } } - -export { Config }; """ test_file = tmp_path / "config.ts" test_file.write_text(source) @@ -1010,7 +971,7 @@ class Config { def test_constructor_with_jsdoc(self, js_support, tmp_path): """Test that constructor with JSDoc is fully extracted.""" source = """\ -class Logger { +export class Logger { /** * Create a new Logger instance. * @param {string} prefix - The prefix to use for log messages. @@ -1056,7 +1017,7 @@ class Logger { def test_static_method_includes_constructor(self, js_support, tmp_path): """Test that static method extraction also includes constructor context.""" source = """\ -class Factory { +export class Factory { constructor(config) { this.config = config; } @@ -1212,13 +1173,11 @@ def test_extract_same_file_interface_from_parameter(self, ts_support, tmp_path): y: number; } -function distance(p1: Point, p2: Point): number { +export function distance(p1: Point, p2: Point): number { const dx = p2.x - p1.x; const dy = p2.y - p1.y; return Math.sqrt(dx * dx + dy * dy); } - -export { distance }; """ test_file = tmp_path / "geometry.ts" test_file.write_text(source) @@ -1251,7 +1210,7 @@ def test_extract_same_file_enum_from_parameter(self, ts_support, tmp_path): FAILURE = 'failure', } -function processStatus(status: Status): string { +export function processStatus(status: Status): string { switch (status) { case Status.PENDING: return 'Processing...'; @@ -1261,8 +1220,6 @@ def test_extract_same_file_enum_from_parameter(self, ts_support, tmp_path): return 'Failed!'; } } - -export { processStatus }; """ test_file = tmp_path / "status.ts" test_file.write_text(source) @@ -1295,11 +1252,9 @@ def test_extract_same_file_type_alias_from_return_type(self, ts_support, tmp_pat success: boolean; }; -function compute(x: number): Result { +export function compute(x: number): Result { return { value: x * 2, success: true }; } - -export { compute }; """ test_file = tmp_path / "compute.ts" test_file.write_text(source) @@ -1331,7 +1286,7 @@ def test_extract_class_field_types(self, ts_support, tmp_path): retries: number; } -class Service { +export class Service { private config: Config; constructor(config: Config) { @@ -1342,8 +1297,6 @@ class Service { return this.config.timeout; } } - -export { Service }; """ test_file = tmp_path / "service.ts" test_file.write_text(source) @@ -1372,11 +1325,9 @@ class Service { def test_primitive_types_not_included(self, ts_support, tmp_path): """Test that primitive types (number, string, etc.) are not extracted.""" source = """\ -function add(a: number, b: number): number { +export function add(a: number, b: number): number { return a + b; } - -export { add }; """ test_file = tmp_path / "add.ts" test_file.write_text(source) @@ -1405,11 +1356,9 @@ def test_extract_multiple_types(self, ts_support, tmp_path): height: number; } -function createRect(origin: Point, size: Size): { origin: Point; size: Size } { +export function createRect(origin: Point, size: Size): { origin: Point; size: Size } { return { origin, size }; } - -export { createRect }; """ test_file = tmp_path / "rect.ts" test_file.write_text(source) @@ -1447,7 +1396,7 @@ def test_extract_imported_type_definition(self, ts_support, ts_types_project): geometry_file.write_text("""\ import { Point, CalculationConfig } from './types'; -function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number { +export function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number { const dx = p2.x - p1.x; const dy = p2.y - p1.y; const distance = Math.sqrt(dx * dx + dy * dy); @@ -1458,8 +1407,6 @@ def test_extract_imported_type_definition(self, ts_support, ts_types_project): } return distance; } - -export { calculateDistance }; """) functions = ts_support.discover_functions(geometry_file) @@ -1506,11 +1453,9 @@ def test_type_with_jsdoc_included(self, ts_support, tmp_path): name: string; } -function greetUser(user: User): string { +export function greetUser(user: User): string { return `Hello, ${user.name}!`; } - -export { greetUser }; """ test_file = tmp_path / "user.ts" test_file.write_text(source) diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index c5b2cc001..9e251804a 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -757,7 +757,7 @@ class TestSimpleFunctionReplacement: def test_replace_simple_function_body(self, js_support, temp_project): """Test replacing a simple function body preserves structure exactly.""" original_source = """\ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -769,7 +769,7 @@ def test_replace_simple_function_body(self, js_support, temp_project): # Optimized version with different body optimized_code = """\ -function add(a, b) { +export function add(a, b) { // Optimized: direct return return a + b; } @@ -778,7 +778,7 @@ def test_replace_simple_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function add(a, b) { +export function add(a, b) { // Optimized: direct return return a + b; } @@ -789,7 +789,7 @@ def test_replace_simple_function_body(self, js_support, temp_project): def test_replace_function_with_multiple_statements(self, js_support, temp_project): """Test replacing function with complex multi-statement body.""" original_source = """\ -function processData(data) { +export function processData(data) { const result = []; for (let i = 0; i < data.length; i++) { result.push(data[i] * 2); @@ -805,7 +805,7 @@ def test_replace_function_with_multiple_statements(self, js_support, temp_projec # Optimized version using map optimized_code = """\ -function processData(data) { +export function processData(data) { return data.map(x => x * 2); } """ @@ -813,7 +813,7 @@ def test_replace_function_with_multiple_statements(self, js_support, temp_projec result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function processData(data) { +export function processData(data) { return data.map(x => x * 2); } """ @@ -825,12 +825,12 @@ def test_replace_preserves_surrounding_code(self, js_support, temp_project): original_source = """\ const CONFIG = { debug: true }; -function targetFunction(x) { +export function targetFunction(x) { console.log(x); return x * 2; } -function otherFunction(y) { +export function otherFunction(y) { return y + 1; } @@ -843,7 +843,7 @@ def test_replace_preserves_surrounding_code(self, js_support, temp_project): target_func = next(f for f in functions if f.function_name == "targetFunction") optimized_code = """\ -function targetFunction(x) { +export function targetFunction(x) { return x << 1; } """ @@ -853,11 +853,11 @@ def test_replace_preserves_surrounding_code(self, js_support, temp_project): expected_result = """\ const CONFIG = { debug: true }; -function targetFunction(x) { +export function targetFunction(x) { return x << 1; } -function otherFunction(y) { +export function otherFunction(y) { return y + 1; } @@ -873,7 +873,7 @@ class TestClassMethodReplacement: def test_replace_class_method_body(self, js_support, temp_project): """Test replacing a class method body preserves class structure.""" original_source = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -896,7 +896,7 @@ class Calculator { # Optimized version provided in class context optimized_code = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -910,7 +910,7 @@ class Calculator { result = js_support.replace_function(original_source, add_method, optimized_code) expected_result = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -930,7 +930,7 @@ class Calculator { def test_replace_method_calling_sibling_methods(self, js_support, temp_project): """Test replacing method that calls other methods in same class.""" original_source = """\ -class DataProcessor { +export class DataProcessor { constructor() { this.cache = new Map(); } @@ -958,7 +958,7 @@ class DataProcessor { process_method = next(f for f in functions if f.function_name == "process") optimized_code = """\ -class DataProcessor { +export class DataProcessor { constructor() { this.cache = new Map(); } @@ -975,7 +975,7 @@ class DataProcessor { result = js_support.replace_function(original_source, process_method, optimized_code) expected_result = """\ -class DataProcessor { +export class DataProcessor { constructor() { this.cache = new Map(); } @@ -1008,7 +1008,7 @@ def test_replace_preserves_jsdoc_above_function(self, js_support, temp_project): * @param {number} b - Second number * @returns {number} The sum */ -function add(a, b) { +export function add(a, b) { const sum = a + b; return sum; } @@ -1020,13 +1020,7 @@ def test_replace_preserves_jsdoc_above_function(self, js_support, temp_project): func = functions[0] optimized_code = """\ -/** - * Calculates the sum of two numbers. - * @param {number} a - First number - * @param {number} b - Second number - * @returns {number} The sum - */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -1040,7 +1034,7 @@ def test_replace_preserves_jsdoc_above_function(self, js_support, temp_project): * @param {number} b - Second number * @returns {number} The sum */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -1054,7 +1048,7 @@ def test_replace_class_method_with_jsdoc(self, js_support, temp_project): * A simple cache implementation. * @class Cache */ -class Cache { +export class Cache { constructor() { this.data = new Map(); } @@ -1103,7 +1097,7 @@ class Cache { * A simple cache implementation. * @class Cache */ -class Cache { +export class Cache { constructor() { this.data = new Map(); } @@ -1128,7 +1122,7 @@ class TestAsyncFunctionReplacement: def test_replace_async_function_body(self, js_support, temp_project): """Test replacing async function preserves async keyword.""" original_source = """\ -async function fetchData(url) { +export async function fetchData(url) { const response = await fetch(url); const data = await response.json(); return data; @@ -1141,7 +1135,7 @@ def test_replace_async_function_body(self, js_support, temp_project): func = functions[0] optimized_code = """\ -async function fetchData(url) { +export async function fetchData(url) { return (await fetch(url)).json(); } """ @@ -1149,7 +1143,7 @@ def test_replace_async_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -async function fetchData(url) { +export async function fetchData(url) { return (await fetch(url)).json(); } """ @@ -1159,7 +1153,7 @@ def test_replace_async_function_body(self, js_support, temp_project): def test_replace_async_class_method(self, js_support, temp_project): """Test replacing async class method.""" original_source = """\ -class ApiClient { +export class ApiClient { constructor(baseUrl) { this.baseUrl = baseUrl; } @@ -1198,7 +1192,7 @@ class ApiClient { result = js_support.replace_function(original_source, get_method, optimized_code) expected_result = """\ -class ApiClient { +export class ApiClient { constructor(baseUrl) { this.baseUrl = baseUrl; } @@ -1220,7 +1214,7 @@ class TestGeneratorFunctionReplacement: def test_replace_generator_function_body(self, js_support, temp_project): """Test replacing generator function preserves generator syntax.""" original_source = """\ -function* range(start, end) { +export function* range(start, end) { for (let i = start; i < end; i++) { yield i; } @@ -1233,7 +1227,7 @@ def test_replace_generator_function_body(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function* range(start, end) { +export function* range(start, end) { let i = start; while (i < end) yield i++; } @@ -1242,7 +1236,7 @@ def test_replace_generator_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function* range(start, end) { +export function* range(start, end) { let i = start; while (i < end) yield i++; } @@ -1257,7 +1251,7 @@ class TestTypeScriptReplacement: def test_replace_typescript_function_with_types(self, ts_support, temp_project): """Test replacing TypeScript function preserves type annotations.""" original_source = """\ -function processArray(items: number[]): number { +export function processArray(items: number[]): number { let sum = 0; for (let i = 0; i < items.length; i++) { sum += items[i]; @@ -1272,7 +1266,7 @@ def test_replace_typescript_function_with_types(self, ts_support, temp_project): func = functions[0] optimized_code = """\ -function processArray(items: number[]): number { +export function processArray(items: number[]): number { return items.reduce((a, b) => a + b, 0); } """ @@ -1280,7 +1274,7 @@ def test_replace_typescript_function_with_types(self, ts_support, temp_project): result = ts_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function processArray(items: number[]): number { +export function processArray(items: number[]): number { return items.reduce((a, b) => a + b, 0); } """ @@ -1290,7 +1284,7 @@ def test_replace_typescript_function_with_types(self, ts_support, temp_project): def test_replace_typescript_class_method_with_generics(self, ts_support, temp_project): """Test replacing TypeScript generic class method.""" original_source = """\ -class Container { +export class Container { private items: T[] = []; add(item: T): void { @@ -1325,7 +1319,7 @@ class Container { result = ts_support.replace_function(original_source, get_all_method, optimized_code) expected_result = """\ -class Container { +export class Container { private items: T[] = []; add(item: T): void { @@ -1349,7 +1343,7 @@ def test_replace_typescript_interface_typed_function(self, ts_support, temp_proj email: string; } -function createUser(name: string, email: string): User { +export function createUser(name: string, email: string): User { const id = Math.random().toString(36).substring(2, 15); const user: User = { id: id, @@ -1366,7 +1360,7 @@ def test_replace_typescript_interface_typed_function(self, ts_support, temp_proj func = next(f for f in functions if f.function_name == "createUser") optimized_code = """\ -function createUser(name: string, email: string): User { +export function createUser(name: string, email: string): User { return { id: Math.random().toString(36).substring(2, 15), name, @@ -1384,7 +1378,7 @@ def test_replace_typescript_interface_typed_function(self, ts_support, temp_proj email: string; } -function createUser(name: string, email: string): User { +export function createUser(name: string, email: string): User { return { id: Math.random().toString(36).substring(2, 15), name, @@ -1402,7 +1396,7 @@ class TestComplexReplacements: def test_replace_function_with_nested_functions(self, js_support, temp_project): """Test replacing function that contains nested function definitions.""" original_source = """\ -function processItems(items) { +export function processItems(items) { function helper(item) { return item * 2; } @@ -1421,7 +1415,7 @@ def test_replace_function_with_nested_functions(self, js_support, temp_project): process_func = next(f for f in functions if f.function_name == "processItems") optimized_code = """\ -function processItems(items) { +export function processItems(items) { const helper = x => x * 2; return items.map(helper); } @@ -1430,7 +1424,7 @@ def test_replace_function_with_nested_functions(self, js_support, temp_project): result = js_support.replace_function(original_source, process_func, optimized_code) expected_result = """\ -function processItems(items) { +export function processItems(items) { const helper = x => x * 2; return items.map(helper); } @@ -1441,7 +1435,7 @@ def test_replace_function_with_nested_functions(self, js_support, temp_project): def test_replace_multiple_methods_sequentially(self, js_support, temp_project): """Test replacing multiple methods in the same class sequentially.""" original_source = """\ -class MathUtils { +export class MathUtils { static sum(arr) { let total = 0; for (let i = 0; i < arr.length; i++) { @@ -1478,7 +1472,7 @@ class MathUtils { result = js_support.replace_function(original_source, sum_method, optimized_sum) expected_after_first = """\ -class MathUtils { +export class MathUtils { static sum(arr) { return arr.reduce((a, b) => a + b, 0); } @@ -1499,7 +1493,7 @@ class MathUtils { def test_replace_function_with_complex_destructuring(self, js_support, temp_project): """Test replacing function with complex parameter destructuring.""" original_source = """\ -function processConfig({ server: { host, port }, database: { url, poolSize } }) { +export function processConfig({ server: { host, port }, database: { url, poolSize } }) { const serverUrl = host + ':' + port; const dbConnection = url + '?poolSize=' + poolSize; return { @@ -1515,7 +1509,7 @@ def test_replace_function_with_complex_destructuring(self, js_support, temp_proj func = functions[0] optimized_code = """\ -function processConfig({ server: { host, port }, database: { url, poolSize } }) { +export function processConfig({ server: { host, port }, database: { url, poolSize } }) { return { server: `${host}:${port}`, db: `${url}?poolSize=${poolSize}` @@ -1526,7 +1520,7 @@ def test_replace_function_with_complex_destructuring(self, js_support, temp_proj result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function processConfig({ server: { host, port }, database: { url, poolSize } }) { +export function processConfig({ server: { host, port }, database: { url, poolSize } }) { return { server: `${host}:${port}`, db: `${url}?poolSize=${poolSize}` @@ -1543,7 +1537,7 @@ class TestEdgeCases: def test_replace_minimal_function_body(self, js_support, temp_project): """Test replacing function with minimal body.""" original_source = """\ -function minimal() { +export function minimal() { return null; } """ @@ -1554,7 +1548,7 @@ def test_replace_minimal_function_body(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function minimal() { +export function minimal() { return { initialized: true, timestamp: Date.now() }; } """ @@ -1562,7 +1556,7 @@ def test_replace_minimal_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function minimal() { +export function minimal() { return { initialized: true, timestamp: Date.now() }; } """ @@ -1572,7 +1566,7 @@ def test_replace_minimal_function_body(self, js_support, temp_project): def test_replace_single_line_function(self, js_support, temp_project): """Test replacing single-line function.""" original_source = """\ -function identity(x) { return x; } +export function identity(x) { return x; } """ file_path = temp_project / "utils.js" file_path.write_text(original_source, encoding="utf-8") @@ -1581,13 +1575,13 @@ def test_replace_single_line_function(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function identity(x) { return x ?? null; } +export function identity(x) { return x ?? null; } """ result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function identity(x) { return x ?? null; } +export function identity(x) { return x ?? null; } """ assert result == expected_result assert js_support.validate_syntax(result) is True @@ -1595,7 +1589,7 @@ def test_replace_single_line_function(self, js_support, temp_project): def test_replace_function_with_special_characters_in_strings(self, js_support, temp_project): """Test replacing function containing special characters in strings.""" original_source = """\ -function formatMessage(name) { +export function formatMessage(name) { const greeting = 'Hello, ' + name + '!'; const special = "Contains \\"quotes\\" and \\n newlines"; return greeting + ' ' + special; @@ -1608,7 +1602,7 @@ def test_replace_function_with_special_characters_in_strings(self, js_support, t func = functions[0] optimized_code = """\ -function formatMessage(name) { +export function formatMessage(name) { return `Hello, ${name}! Contains "quotes" and newlines`; } @@ -1617,7 +1611,7 @@ def test_replace_function_with_special_characters_in_strings(self, js_support, t result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function formatMessage(name) { +export function formatMessage(name) { return `Hello, ${name}! Contains "quotes" and newlines`; } @@ -1628,7 +1622,7 @@ def test_replace_function_with_special_characters_in_strings(self, js_support, t def test_replace_function_with_regex(self, js_support, temp_project): """Test replacing function containing regex patterns.""" original_source = """\ -function validateEmail(email) { +export function validateEmail(email) { const pattern = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/; if (pattern.test(email)) { return true; @@ -1643,7 +1637,7 @@ def test_replace_function_with_regex(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function validateEmail(email) { +export function validateEmail(email) { return /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/.test(email); } """ @@ -1651,7 +1645,7 @@ def test_replace_function_with_regex(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function validateEmail(email) { +export function validateEmail(email) { return /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/.test(email); } """ @@ -1665,11 +1659,11 @@ class TestModuleExportHandling: def test_replace_exported_function_commonjs(self, js_support, temp_project): """Test replacing function in CommonJS module preserves exports.""" original_source = """\ -function helper(x) { +export function helper(x) { return x * 2; } -function main(data) { +export function main(data) { const results = []; for (let i = 0; i < data.length; i++) { results.push(helper(data[i])); @@ -1686,7 +1680,7 @@ def test_replace_exported_function_commonjs(self, js_support, temp_project): main_func = next(f for f in functions if f.function_name == "main") optimized_code = """\ -function main(data) { +export function main(data) { return data.map(helper); } """ @@ -1694,11 +1688,11 @@ def test_replace_exported_function_commonjs(self, js_support, temp_project): result = js_support.replace_function(original_source, main_func, optimized_code) expected_result = """\ -function helper(x) { +export function helper(x) { return x * 2; } -function main(data) { +export function main(data) { return data.map(helper); } @@ -1757,18 +1751,18 @@ def test_all_replacements_produce_valid_syntax(self, js_support, temp_project): test_cases = [ # (original, optimized, description) ( - "function f(x) { return x + 1; }", - "function f(x) { return ++x; }", + "export function f(x) { return x + 1; }", + "export function f(x) { return ++x; }", "increment replacement" ), ( - "function f(arr) { return arr.length > 0; }", - "function f(arr) { return !!arr.length; }", + "export function f(arr) { return arr.length > 0; }", + "export function f(arr) { return !!arr.length; }", "boolean conversion" ), ( - "function f(a, b) { if (a) { return a; } return b; }", - "function f(a, b) { return a || b; }", + "export function f(a, b) { if (a) { return a; } return b; }", + "export function f(a, b) { return a || b; }", "logical OR replacement" ), ] diff --git a/tests/test_languages/test_language_parity.py b/tests/test_languages/test_language_parity.py index ae57eb426..2b2035c84 100644 --- a/tests/test_languages/test_language_parity.py +++ b/tests/test_languages/test_language_parity.py @@ -38,7 +38,7 @@ def add(a, b): return a + b """, javascript=""" -function add(a, b) { +export function add(a, b) { return a + b; } """, @@ -58,15 +58,15 @@ def multiply(a, b): return a * b """, javascript=""" -function add(a, b) { +export function add(a, b) { return a + b; } -function subtract(a, b) { +export function subtract(a, b) { return a - b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """, @@ -83,11 +83,11 @@ def without_return(): print("hello") """, javascript=""" -function withReturn() { +export function withReturn() { return 1; } -function withoutReturn() { +export function withoutReturn() { console.log("hello"); } """, @@ -105,7 +105,7 @@ def multiply(self, a, b): return a * b """, javascript=""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -128,11 +128,11 @@ def sync_function(): return 1 """, javascript=""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url); } -function syncFunction() { +export function syncFunction() { return 1; } """, @@ -148,7 +148,7 @@ def inner(): return inner() """, javascript=""" -function outer() { +export function outer() { function inner() { return 1; } @@ -167,7 +167,7 @@ def helper(x): return x * 2 """, javascript=""" -class Utils { +export class Utils { static helper(x) { return x * 2; } @@ -194,7 +194,7 @@ def standalone(): return 42 """, javascript=""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -204,13 +204,13 @@ class Calculator { } } -class StringUtils { +export class StringUtils { reverse(s) { return s.split('').reverse().join(''); } } -function standalone() { +export function standalone() { return 42; } """, @@ -227,11 +227,11 @@ def sync_func(): return 2 """, javascript=""" -async function asyncFunc() { +export async function asyncFunc() { return 1; } -function syncFunc() { +export function syncFunc() { return 2; } """, @@ -249,11 +249,11 @@ def method(self): return 2 """, javascript=""" -function standalone() { +export function standalone() { return 1; } -class MyClass { +export class MyClass { method() { return 2; } @@ -906,7 +906,7 @@ def test_discover_and_replace_workflow(self, python_support, js_support): return n return fibonacci(n - 1) + fibonacci(n - 2) """ - js_original = """function fibonacci(n) { + js_original = """export function fibonacci(n) { if (n <= 1) { return n; } @@ -933,7 +933,7 @@ def test_discover_and_replace_workflow(self, python_support, js_support): memo[i] = memo[i-1] + memo[i-2] return memo[n] """ - js_optimized = """function fibonacci(n) { + js_optimized = """export function fibonacci(n) { // Memoized version const memo = {0: 0, 1: 1}; for (let i = 2; i <= n; i++) { @@ -994,13 +994,13 @@ def test_function_info_fields_populated(self, python_support, js_support): def test_arrow_functions_unique_to_js(self, js_support): """JavaScript arrow functions should be discovered (no Python equivalent).""" js_code = """ -const add = (a, b) => { +export const add = (a, b) => { return a + b; }; -const multiply = (x, y) => x * y; +export const multiply = (x, y) => x * y; -const identity = x => x; +export const identity = x => x; """ js_file = write_temp_file(js_code, ".js") funcs = js_support.discover_functions(js_file) @@ -1021,7 +1021,7 @@ def number_generator(): return 3 """ js_code = """ -function* numberGenerator() { +export function* numberGenerator() { yield 1; yield 2; return 3; @@ -1065,11 +1065,11 @@ def multi_decorated(): def test_function_expressions_js(self, js_support): """JavaScript function expressions should be discovered.""" js_code = """ -const add = function(a, b) { +export const add = function(a, b) { return a + b; }; -const namedExpr = function myFunc(x) { +export const namedExpr = function myFunc(x) { return x * 2; }; """ @@ -1132,7 +1132,7 @@ def greeting(): return "Hello, 世界! 🌍" """ js_code = """ -function greeting() { +export function greeting() { return "Hello, 世界! 🌍"; } """ diff --git a/tests/test_languages/test_multi_file_code_replacer.py b/tests/test_languages/test_multi_file_code_replacer.py index 65f3930e5..b4d2854b6 100644 --- a/tests/test_languages/test_multi_file_code_replacer.py +++ b/tests/test_languages/test_multi_file_code_replacer.py @@ -168,6 +168,11 @@ def test_js_replcement() -> None: const { sumArray, average, findMax, findMin } = require('./math_helpers'); +/** + * Calculate statistics for an array of numbers. + * @param numbers - Array of numbers to analyze + * @returns Object containing sum, average, min, max, and range + */ /** * This is a modified comment */ @@ -211,7 +216,7 @@ def test_js_replcement() -> None: * @param numbers - Array of numbers to normalize * @returns Normalized array */ -function normalizeArray(numbers) { +export function normalizeArray(numbers) { if (numbers.length === 0) return []; const min = findMin(numbers); @@ -231,7 +236,7 @@ def test_js_replcement() -> None: * @param weights - Array of weights (same length as values) * @returns The weighted average */ -function weightedAverage(values, weights) { +export function weightedAverage(values, weights) { if (values.length === 0 || values.length !== weights.length) { return 0; } @@ -264,7 +269,7 @@ def test_js_replcement() -> None: * @param numbers - Array of numbers to sum * @returns The sum of all numbers */ -function sumArray(numbers) { +export function sumArray(numbers) { // Intentionally inefficient - using reduce with spread operator let result = 0; for (let i = 0; i < numbers.length; i++) { @@ -278,11 +283,16 @@ def test_js_replcement() -> None: * @param numbers - Array of numbers * @returns The average value */ -function average(numbers) { +export function average(numbers) { if (numbers.length === 0) return 0; return sumArray(numbers) / numbers.length; } +/** + * Find the maximum value in an array. + * @param numbers - Array of numbers + * @returns The maximum value + */ /** * Normalize an array of numbers to a 0-1 range. * @param numbers - Array of numbers to normalize @@ -301,6 +311,11 @@ def test_js_replcement() -> None: return max; } +/** + * Find the minimum value in an array. + * @param numbers - Array of numbers + * @returns The minimum value + */ /** * Find the minimum value in an array. * @param numbers - Array of numbers diff --git a/tests/test_languages/test_typescript_code_extraction.py b/tests/test_languages/test_typescript_code_extraction.py index f97049943..b344a2492 100644 --- a/tests/test_languages/test_typescript_code_extraction.py +++ b/tests/test_languages/test_typescript_code_extraction.py @@ -119,7 +119,7 @@ def test_extract_simple_function(self, ts_support): """Test extracting code context for a simple function.""" with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: f.write(""" -function add(a: number, b: number): number { +export function add(a: number, b: number): number { return a + b; } """) @@ -147,7 +147,7 @@ def test_extract_async_function_with_template_literal(self, ts_support): const command_args = process.argv.slice(3); -async function execMongoEval(queryExpression, appsmithMongoURI) { +export async function execMongoEval(queryExpression, appsmithMongoURI) { queryExpression = queryExpression.trim(); if (command_args.includes("--pretty")) { @@ -186,7 +186,7 @@ def test_extract_function_with_complex_try_catch(self, ts_support): import fsPromises from "fs/promises"; import path from "path"; -async function figureOutContentsPath(root: string): Promise { +export async function figureOutContentsPath(root: string): Promise { const subfolders = await fsPromises.readdir(root, { withFileTypes: true }); try { @@ -238,7 +238,7 @@ def test_extracted_code_includes_imports(self, ts_support): import fs from "fs"; import path from "path"; -function readConfig(filename: string): string { +export function readConfig(filename: string): string { const fullPath = path.join(__dirname, filename); return fs.readFileSync(fullPath, "utf8"); } @@ -264,7 +264,7 @@ def test_extracted_code_includes_global_variables(self, ts_support): const CONFIG = { timeout: 5000 }; const MAX_RETRIES = 3; -async function fetchWithRetry(url: string): Promise { +export async function fetchWithRetry(url: string): Promise { for (let i = 0; i < MAX_RETRIES; i++) { try { const response = await fetch(url, { signal: AbortSignal.timeout(CONFIG.timeout) }); @@ -289,6 +289,164 @@ def test_extracted_code_includes_global_variables(self, ts_support): assert ts_support.validate_syntax(code_context.target_code) is True +class TestSameClassHelperExtraction: + """Tests for same-class helper method extraction. + + When a class method calls other methods from the same class, those helper + methods should be included inside the class wrapper (not appended outside), + because they may use class-specific syntax like 'private'. + """ + + def test_private_helper_method_inside_class_wrapper(self, ts_support): + """Test that private helper methods are included inside the class wrapper.""" + with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: + # Export the class and add return statements so discover_functions finds the methods + f.write(""" +export class EndpointGroup { + private endpoints: any[] = []; + + constructor() { + this.endpoints = []; + } + + post(path: string, handler: Function): EndpointGroup { + this.addEndpoint("POST", path, handler); + return this; + } + + private addEndpoint(method: string, path: string, handler: Function): void { + this.endpoints.push({ method, path, handler }); + return; + } +} +""") + f.flush() + file_path = Path(f.name) + + # Discover the 'post' method + functions = ts_support.discover_functions(file_path) + post_method = None + for func in functions: + if func.function_name == "post": + post_method = func + break + + assert post_method is not None, "post method should be discovered" + + # Extract code context + code_context = ts_support.extract_code_context( + post_method, file_path.parent, file_path.parent + ) + + # The extracted code should be syntactically valid + assert ts_support.validate_syntax(code_context.target_code) is True, ( + f"Extracted code should be valid TypeScript:\n{code_context.target_code}" + ) + + # Both post and addEndpoint should be inside the class + assert "class EndpointGroup" in code_context.target_code + assert "post(" in code_context.target_code + assert "private addEndpoint" in code_context.target_code + + # The private method should be inside the class, not outside + # Check that addEndpoint appears BEFORE the closing brace of the class + class_end_index = code_context.target_code.rfind("}") + add_endpoint_index = code_context.target_code.find("addEndpoint") + assert add_endpoint_index < class_end_index, ( + "addEndpoint should be inside the class wrapper" + ) + + def test_multiple_private_helpers_inside_class(self, ts_support): + """Test that multiple private helpers are all included inside the class.""" + with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: + f.write(""" +export class Router { + private routes: Map = new Map(); + + addRoute(path: string, handler: Function): boolean { + const normalizedPath = this.normalizePath(path); + this.validatePath(normalizedPath); + this.routes.set(normalizedPath, handler); + return true; + } + + private normalizePath(path: string): string { + return path.toLowerCase().trim(); + } + + private validatePath(path: string): boolean { + if (!path.startsWith("/")) { + throw new Error("Path must start with /"); + } + return true; + } +} +""") + f.flush() + file_path = Path(f.name) + + # Discover the 'addRoute' method + functions = ts_support.discover_functions(file_path) + add_route_method = None + for func in functions: + if func.function_name == "addRoute": + add_route_method = func + break + + assert add_route_method is not None + + code_context = ts_support.extract_code_context( + add_route_method, file_path.parent, file_path.parent + ) + + # Should be valid TypeScript + assert ts_support.validate_syntax(code_context.target_code) is True + + # All methods should be inside the class + assert "private normalizePath" in code_context.target_code + assert "private validatePath" in code_context.target_code + + def test_same_class_helpers_filtered_from_helper_list(self, ts_support): + """Test that same-class helpers are not duplicated in the helpers list.""" + with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: + f.write(""" +export class Calculator { + add(a: number, b: number): number { + return this.compute(a, b, "+"); + } + + private compute(a: number, b: number, op: string): number { + if (op === "+") return a + b; + return 0; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = ts_support.discover_functions(file_path) + add_method = None + for func in functions: + if func.function_name == "add": + add_method = func + break + + assert add_method is not None + + code_context = ts_support.extract_code_context( + add_method, file_path.parent, file_path.parent + ) + + # 'compute' should be in target_code (inside class) + assert "compute" in code_context.target_code + + # 'compute' should NOT be in helper_functions (would be duplicate) + helper_names = [h.name for h in code_context.helper_functions] + assert "compute" not in helper_names, ( + "Same-class helper 'compute' should not be in helper_functions list" + ) + + class TestTypeScriptLanguageProperties: """Tests for TypeScript language support properties.""" diff --git a/tests/test_languages/test_typescript_e2e.py b/tests/test_languages/test_typescript_e2e.py index 199094a1d..a638f01a1 100644 --- a/tests/test_languages/test_typescript_e2e.py +++ b/tests/test_languages/test_typescript_e2e.py @@ -285,7 +285,7 @@ def test_function_to_optimize_has_correct_fields(self): with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a: number, b: number): number { return a + b; } @@ -295,7 +295,7 @@ class Calculator { } } -function standalone(x: number): number { +export function standalone(x: number): number { return x * 2; } """)