diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index dab746c47..fdc5a420a 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -1,12 +1,14 @@ from __future__ import annotations import logging +from collections import deque from contextlib import contextmanager from itertools import cycle from typing import TYPE_CHECKING, Optional from rich.console import Console from rich.logging import RichHandler +from rich.panel import Panel from rich.progress import ( BarColumn, MofNCompleteColumn, @@ -24,10 +26,13 @@ from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Callable, Generator + from pathlib import Path from rich.progress import TaskID + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.base import DependencyResolver, IndexResult from codeflash.lsp.lsp_message import LspMessage DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG @@ -196,3 +201,151 @@ def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Pro ) as progress: task_id = progress.add_task(description, total=total) yield progress, task_id + + +MAX_TREE_ENTRIES = 8 + + +@contextmanager +def call_graph_live_display( + total: int, project_root: Path | None = None +) -> Generator[Callable[[IndexResult], None], None, None]: + from rich.console import Group + from rich.live import Live + from rich.panel import Panel + from rich.text import Text + from rich.tree import Tree + + if is_LSP_enabled(): + lsp_log(LspTextMessage(text="Building call graph", takes_time=True)) + yield lambda _: None + return + + progress = Progress( + SpinnerColumn(next(spinners)), + TextColumn("[progress.description]{task.description}"), + BarColumn(complete_style="cyan", finished_style="green", pulse_style="yellow"), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), 
+ auto_refresh=False, + ) + task_id = progress.add_task("Analyzing files", total=total) + + results: deque[IndexResult] = deque(maxlen=MAX_TREE_ENTRIES) + stats = {"indexed": 0, "cached": 0, "edges": 0, "external": 0, "errors": 0} + + tree = Tree("[bold]Recent Files[/bold]") + stats_text = Text("0 calls found", style="dim") + panel = Panel( + Group(progress, Text(""), tree, Text(""), stats_text), title="Building Call Graph", border_style="cyan" + ) + + def create_tree_node(result: IndexResult) -> Tree: + if project_root: + try: + name = str(result.file_path.resolve().relative_to(project_root.resolve())) + except ValueError: + name = f"{result.file_path.parent.name}/{result.file_path.name}" + else: + name = f"{result.file_path.parent.name}/{result.file_path.name}" + + if result.error: + return Tree(f"[red]{name} (error)[/red]") + + if result.cached: + return Tree(f"[dim]{name} (cached)[/dim]") + + local_edges = result.num_edges - result.cross_file_edges + edge_info = [] + + if local_edges: + edge_info.append(f"{local_edges} calls in same file") + if result.cross_file_edges: + edge_info.append(f"{result.cross_file_edges} calls from other modules") + + label = ", ".join(edge_info) if edge_info else "no calls" + return Tree(f"[cyan]{name}[/cyan] [dim]{label}[/dim]") + + def refresh_display() -> None: + tree.children = [create_tree_node(r) for r in results] + tree.children.extend([Tree(" ")] * (MAX_TREE_ENTRIES - len(results))) + + # Update stats + stat_parts = [] + if stats["indexed"]: + stat_parts.append(f"{stats['indexed']} files analyzed") + if stats["cached"]: + stat_parts.append(f"{stats['cached']} cached") + if stats["errors"]: + stat_parts.append(f"{stats['errors']} errors") + stat_parts.append(f"{stats['edges']} calls found") + if stats["external"]: + stat_parts.append(f"{stats['external']} cross-file calls") + + stats_text.truncate(0) + stats_text.append(" · ".join(stat_parts), style="dim") + + batch: list[IndexResult] = [] + + def process_batch() -> None: + 
for result in batch: + results.append(result) + + if result.error: + stats["errors"] += 1 + elif result.cached: + stats["cached"] += 1 + else: + stats["indexed"] += 1 + stats["edges"] += result.num_edges + stats["external"] += result.cross_file_edges + + progress.advance(task_id) + + batch.clear() + refresh_display() + live.refresh() + + def update(result: IndexResult) -> None: + batch.append(result) + if len(batch) >= 8: + process_batch() + + with Live(panel, console=console, transient=False, auto_refresh=False) as live: + yield update + if batch: + process_batch() + + + def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path, list[FunctionToOptimize]]) -> None: + total_functions = sum(map(len, file_to_funcs.values())) + if not total_functions: + return + + # Build the mapping expected by the dependency resolver + file_items = file_to_funcs.items() + mapping = {file_path: {func.qualified_name for func in funcs} for file_path, funcs in file_items} + + callee_counts = call_graph.count_callees_per_function(mapping) + + # Use built-in sum for a single C-level pass over the counts + with_context = sum(1 for count in callee_counts.values() if count > 0) + # NOTE(review): total/average callee counts were computed here but never + # surfaced in the summary output, so they are dropped as dead code. + + leaf_functions = total_functions - with_context + + function_label = "function" if total_functions == 1 else "functions" + + summary = ( + f"{total_functions} {function_label} ready for optimization\n" + f"Uses other functions: {with_context} · " + f"Standalone: {leaf_functions}" + ) + + if is_LSP_enabled(): + lsp_log(LspTextMessage(text=summary)) + return + + console.print(Panel(summary, title="Call Graph Summary", border_style="cyan")) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index d4478207c..3ad5eba2d 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -871,8 +871,7 @@ def replace_optimized_code( [ 
callee.qualified_name for callee in code_context.helper_functions - if callee.file_path == module_path - and (callee.jedi_definition is None or callee.jedi_definition.type != "class") + if callee.file_path == module_path and callee.definition_type != "class" ] ), candidate.source_code, diff --git a/codeflash/code_utils/compat.py b/codeflash/code_utils/compat.py index eb4e5b561..b73a6a5a7 100644 --- a/codeflash/code_utils/compat.py +++ b/codeflash/code_utils/compat.py @@ -2,46 +2,16 @@ import sys import tempfile from pathlib import Path -from typing import TYPE_CHECKING from platformdirs import user_config_dir -if TYPE_CHECKING: - codeflash_temp_dir: Path - codeflash_cache_dir: Path - codeflash_cache_db: Path +LF: str = os.linesep +IS_POSIX: bool = os.name != "nt" +SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() +codeflash_cache_dir: Path = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True)) -class Compat: - # os-independent newline - LF: str = os.linesep +codeflash_temp_dir: Path = Path(tempfile.gettempdir()) / "codeflash" +codeflash_temp_dir.mkdir(parents=True, exist_ok=True) - SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() - - IS_POSIX: bool = os.name != "nt" - - @property - def codeflash_cache_dir(self) -> Path: - return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True)) - - @property - def codeflash_temp_dir(self) -> Path: - temp_dir = Path(tempfile.gettempdir()) / "codeflash" - if not temp_dir.exists(): - temp_dir.mkdir(parents=True, exist_ok=True) - return temp_dir - - @property - def codeflash_cache_db(self) -> Path: - return self.codeflash_cache_dir / "codeflash_cache.db" - - -_compat = Compat() - - -codeflash_temp_dir = _compat.codeflash_temp_dir -codeflash_cache_dir = _compat.codeflash_cache_dir -codeflash_cache_db = _compat.codeflash_cache_db -LF = _compat.LF -SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE -IS_POSIX = _compat.IS_POSIX +codeflash_cache_db: 
Path = codeflash_cache_dir / "codeflash_cache.db" diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index b84a136d8..7fd8814d6 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -4,8 +4,8 @@ from typing import Any, Union MAX_TEST_RUN_ITERATIONS = 5 -OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000 -TESTGEN_CONTEXT_TOKEN_LIMIT = 48000 +OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 64000 +TESTGEN_CONTEXT_TOKEN_LIMIT = 64000 INDIVIDUAL_TESTCASE_TIMEOUT = 15 MAX_FUNCTION_TEST_SECONDS = 60 MIN_IMPROVEMENT_THRESHOLD = 0.05 diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 14455b890..56504875a 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -114,32 +114,30 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: ) -class FunctionWithReturnStatement(ast.NodeVisitor): - def __init__(self, file_path: Path) -> None: - self.functions: list[FunctionToOptimize] = [] - self.ast_path: list[FunctionParent] = [] - self.file_path: Path = file_path - - def visit_FunctionDef(self, node: FunctionDef) -> None: - if function_has_return_statement(node) and not function_is_a_property(node): - self.functions.append( - FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:]) - ) - - def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: - if function_has_return_statement(node) and not function_is_a_property(node): - self.functions.append( - FunctionToOptimize( - function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True +def find_functions_with_return_statement(ast_module: ast.Module, file_path: Path) -> list[FunctionToOptimize]: + results: list[FunctionToOptimize] = [] + # (node, parent_path) — iterative DFS avoids RecursionError on deeply nested ASTs + stack: list[tuple[ast.AST, list[FunctionParent]]] = 
[(ast_module, [])] + while stack: + node, ast_path = stack.pop() + if isinstance(node, (FunctionDef, AsyncFunctionDef)): + if function_has_return_statement(node) and not function_is_a_property(node): + results.append( + FunctionToOptimize( + function_name=node.name, + file_path=file_path, + parents=ast_path[:], + is_async=isinstance(node, AsyncFunctionDef), + ) ) - ) - - def generic_visit(self, node: ast.AST) -> None: - if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)): - self.ast_path.append(FunctionParent(node.name, node.__class__.__name__)) - super().generic_visit(node) - if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)): - self.ast_path.pop() + # Don't recurse into function bodies (matches original visitor behaviour) + continue + child_path = ( + [*ast_path, FunctionParent(node.name, node.__class__.__name__)] if isinstance(node, ClassDef) else ast_path + ) + for child in reversed(list(ast.iter_child_nodes(node))): + stack.append((child, child_path)) + return results # ============================================================================= @@ -265,9 +263,7 @@ def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[Funct if DEBUG_MODE: logger.exception(e) return functions - function_name_visitor = FunctionWithReturnStatement(file_path) - function_name_visitor.visit(ast_module) - functions[file_path] = function_name_visitor.functions + functions[file_path] = find_functions_with_return_statement(ast_module, file_path) return functions @@ -992,12 +988,21 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool: # Custom DFS, return True as soon as a Return node is found - stack: list[ast.AST] = [function_node] + stack: list[ast.AST] = list(function_node.body) while stack: node = stack.pop() if isinstance(node, ast.Return): return True - stack.extend(ast.iter_child_nodes(node)) + # Only push child nodes 
that can contain statements; a Return only ever appears in statement + # position, but statement bodies can hang off non-stmt wrapper nodes + # (ast.excepthandler for except blocks, ast.match_case for match arms), + # so restricting the push to ast.stmt would miss a return nested inside + # an except/case block. Instead prune only expression subtrees, which is + # safe because no statement (hence no Return) can occur inside an + # expression (lambda/comprehension bodies cannot contain return). + for child in ast.iter_child_nodes(node): + if not isinstance(child, ast.expr): + stack.append(child) return False diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index 47136f4e7..daf33b43c 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -19,7 +19,9 @@ from codeflash.languages.base import ( CodeContext, + DependencyResolver, HelperFunction, + IndexResult, Language, LanguageSupport, ParentInfo, @@ -82,8 +84,10 @@ __all__ = [ "CodeContext", + "DependencyResolver", "FunctionInfo", "HelperFunction", + "IndexResult", "Language", "LanguageSupport", "ParentInfo", diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 699c49244..8542547a4 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -11,10 +11,11 @@ from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Iterable, Sequence from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.models import FunctionSource from codeflash.languages.language_enum import Language from codeflash.models.function_types import FunctionParent @@ -34,6 +35,16 @@ def __getattr__(name: str) -> Any: raise AttributeError(msg) +@dataclass(frozen=True) +class IndexResult: + file_path: Path + cached: bool + num_edges: int + edges: tuple[tuple[str, str, bool], ...]  # (caller_qn, callee_name, is_cross_file) + cross_file_edges: int + error: bool + + @dataclass class HelperFunction: """A helper function that is a dependency of the target function. 
@@ -192,6 +203,35 @@ class ReferenceInfo: caller_function: str | None = None +@runtime_checkable +class DependencyResolver(Protocol): + """Protocol for language-specific dependency resolution. + + Implementations analyze source files to discover call-graph edges + between functions so the optimizer can extract richer context. + """ + + def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexResult], None] | None = None) -> None: + """Pre-index a batch of files.""" + ... + + def get_callees( + self, file_path_to_qualified_names: dict[Path, set[str]] + ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: + """Return callees for the given functions.""" + ... + + def count_callees_per_function( + self, file_path_to_qualified_names: dict[Path, set[str]] + ) -> dict[tuple[Path, str], int]: + """Return the number of callees for each (file_path, qualified_name) pair.""" + ... + + def close(self) -> None: + """Release resources (e.g. database connections).""" + ... + + @runtime_checkable class LanguageSupport(Protocol): """Protocol defining what a language implementation must provide. @@ -564,6 +604,15 @@ def ensure_runtime_environment(self, project_root: Path) -> bool: # Default implementation: just copy runtime files return False + def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None: + """Create a language-specific dependency resolver, if available. + + Returns: + A DependencyResolver instance, or None if not supported. + + """ + return None + def instrument_existing_test( self, test_path: Path, diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 77c6a3a80..cde098cab 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -2004,6 +2004,9 @@ def ensure_runtime_environment(self, project_root: Path) -> bool: logger.error("Could not install codeflash. 
Please run: npm install --save-dev codeflash") return False + def create_dependency_resolver(self, project_root: Path) -> None: + return None + def instrument_existing_test( self, test_path: Path, diff --git a/codeflash/languages/python/__init__.py b/codeflash/languages/python/__init__.py index e599d1431..939d5941f 100644 --- a/codeflash/languages/python/__init__.py +++ b/codeflash/languages/python/__init__.py @@ -5,6 +5,7 @@ to the LanguageSupport protocol. """ +from codeflash.languages.python.reference_graph import ReferenceGraph from codeflash.languages.python.support import PythonSupport -__all__ = ["PythonSupport"] +__all__ = ["PythonSupport", "ReferenceGraph"] diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 9a4daf726..0e42022f6 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from jedi.api.classes import Name - from codeflash.languages.base import HelperFunction + from codeflash.languages.base import DependencyResolver, HelperFunction from codeflash.languages.python.context.unused_definition_remover import UsageInfo # Error message constants @@ -80,6 +80,7 @@ def get_code_optimization_context( project_root_path: Path, optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, + call_graph: DependencyResolver | None = None, ) -> CodeOptimizationContext: # Route to language-specific implementation for non-Python languages if not is_python(): @@ -88,9 +89,11 @@ def get_code_optimization_context( ) # Get FunctionSource representation of helpers of FTO - helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi( - {function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path - ) + fto_input = {function_to_optimize.file_path: 
{function_to_optimize.qualified_name}} + if call_graph is not None: + helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input) + else: + helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(fto_input, project_root_path) # Add function to optimize into helpers of FTO dict, as they'll be processed together fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path) @@ -252,7 +255,6 @@ def get_code_optimization_context_for_language( fully_qualified_name=helper.qualified_name, only_function_name=helper.name, source_code=helper.source_code, - jedi_definition=None, ) ) @@ -474,7 +476,6 @@ def get_function_to_optimize_as_function_source( fully_qualified_name=name.full_name, only_function_name=name.name, source_code=name.get_line_code(), - jedi_definition=name, ) except Exception as e: logger.exception(f"Error while getting function source: {e}") @@ -542,7 +543,6 @@ def get_function_sources_from_jedi( fully_qualified_name=fqn, only_function_name=func_name, source_code=definition.get_line_code(), - jedi_definition=definition, ) file_path_to_function_source[definition_path].add(function_source) function_source_list.append(function_source) diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index ba6e4d549..3cc7c173a 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -643,15 +643,31 @@ def _analyze_imports_in_optimized_code( helpers_by_file_and_func = defaultdict(dict) helpers_by_file = defaultdict(list) # preserved for "import module" for helper in code_context.helper_functions: - jedi_type = helper.jedi_definition.type if helper.jedi_definition else None - if jedi_type != "class": # Include when jedi_definition is None (non-Python) + jedi_type = helper.definition_type + if jedi_type != "class": # 
Include when definition_type is None (non-Python) func_name = helper.only_function_name module_name = helper.file_path.stem # Cache function lookup for this (module, func) helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper) helpers_by_file[module_name].append(helper) - for node in ast.walk(optimized_ast): + # Collect only import nodes to avoid per-node isinstance checks across the whole AST + class _ImportCollector(ast.NodeVisitor): + def __init__(self) -> None: + self.nodes: list[ast.AST] = [] + + def visit_Import(self, node: ast.Import) -> None: + self.nodes.append(node) + # No need to recurse further for import nodes + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + self.nodes.append(node) + # No need to recurse further for import-from nodes + + collector = _ImportCollector() + collector.visit(optimized_ast) + + for node in collector.nodes: if isinstance(node, ast.ImportFrom): # Handle "from module import function" statements module_name = node.module @@ -802,8 +818,8 @@ def detect_unused_helper_functions( unused_helpers = [] entrypoint_file_path = function_to_optimize.file_path for helper_function in code_context.helper_functions: - jedi_type = helper_function.jedi_definition.type if helper_function.jedi_definition else None - if jedi_type != "class": # Include when jedi_definition is None (non-Python) + jedi_type = helper_function.definition_type + if jedi_type != "class": # Include when definition_type is None (non-Python) # Check if the helper function is called using multiple name variants helper_qualified_name = helper_function.qualified_name helper_simple_name = helper_function.only_function_name diff --git a/codeflash/languages/python/reference_graph.py b/codeflash/languages/python/reference_graph.py new file mode 100644 index 000000000..4f389fd66 --- /dev/null +++ b/codeflash/languages/python/reference_graph.py @@ -0,0 +1,544 @@ +from __future__ import annotations + +import hashlib +import os +import sqlite3 
+from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages +from codeflash.languages.base import IndexResult +from codeflash.models.models import FunctionSource + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from jedi.api.classes import Name + + +# --------------------------------------------------------------------------- +# Module-level helpers (must be top-level for ProcessPoolExecutor pickling) +# --------------------------------------------------------------------------- +# TODO: create call graph. + +_PARALLEL_THRESHOLD = 8 + +# Per-worker state, initialised by _init_index_worker in child processes +_worker_jedi_project: object | None = None +_worker_project_root_str: str | None = None + + +def _init_index_worker(project_root: str) -> None: + import jedi + + global _worker_jedi_project, _worker_project_root_str + _worker_jedi_project = jedi.Project(path=project_root) + _worker_project_root_str = project_root + + +def _resolve_definitions(ref: Name) -> list[Name]: + try: + inferred = ref.infer() + valid = [d for d in inferred if d.type in ("function", "class")] + if valid: + return valid + except Exception: + pass + + try: + result: list[Name] = ref.goto(follow_imports=True, follow_builtin_imports=False) + return result + except Exception: + return [] + + +def _is_valid_definition(definition: Name, caller_qualified_name: str, project_root_str: str) -> bool: + definition_path = definition.module_path + if definition_path is None: + return False + + if not str(definition_path).startswith(project_root_str + os.sep): + return False + + if path_belongs_to_site_packages(definition_path): + return False + + if not definition.full_name or not definition.full_name.startswith(definition.module_name): + return False + + if definition.type not in ("function", 
"class"): + return False + + try: + def_qn = get_qualified_name(definition.module_name, definition.full_name) + if def_qn == caller_qualified_name: + return False + except ValueError: + return False + + try: + from codeflash.optimization.function_context import belongs_to_function_qualified + + if belongs_to_function_qualified(definition, caller_qualified_name): + return False + except Exception: + pass + + return True + + +def _get_enclosing_function_qn(ref: Name) -> str | None: + try: + parent = ref.parent() + if parent is None or parent.type != "function": + return None + if not parent.full_name or not parent.full_name.startswith(parent.module_name): + return None + return get_qualified_name(parent.module_name, parent.full_name) + except (ValueError, AttributeError): + return None + + +def _analyze_file(file_path: Path, jedi_project: object, project_root_str: str) -> tuple[set[tuple[str, ...]], bool]: + """Pure Jedi analysis — no DB access. Returns (edges, had_error).""" + import jedi + + resolved = str(file_path.resolve()) + + try: + script = jedi.Script(path=file_path, project=jedi_project) + refs = script.get_names(all_scopes=True, definitions=False, references=True) + except Exception: + return set(), True + + edges: set[tuple[str, ...]] = set() + + for ref in refs: + try: + caller_qn = _get_enclosing_function_qn(ref) + if caller_qn is None: + continue + + definitions = _resolve_definitions(ref) + if not definitions: + continue + + definition = definitions[0] + definition_path = definition.module_path + if definition_path is None: + continue + + if not _is_valid_definition(definition, caller_qn, project_root_str): + continue + + edge_base = (resolved, caller_qn, str(definition_path)) + + if definition.type == "function": + callee_qn = get_qualified_name(definition.module_name, definition.full_name) + if len(callee_qn.split(".")) > 2: + continue + edges.add( + ( + *edge_base, + callee_qn, + definition.full_name, + definition.name, + definition.type, + 
definition.get_line_code(), + ) + ) + elif definition.type == "class": + init_qn = get_qualified_name(definition.module_name, f"{definition.full_name}.__init__") + if len(init_qn.split(".")) > 2: + continue + edges.add( + ( + *edge_base, + init_qn, + f"{definition.full_name}.__init__", + "__init__", + definition.type, + definition.get_line_code(), + ) + ) + except Exception: + continue + + return edges, False + + +def _index_file_worker(args: tuple[str, str]) -> tuple[str, str, set[tuple[str, ...]], bool]: + """Worker entry point for ProcessPoolExecutor.""" + file_path_str, file_hash = args + assert _worker_project_root_str is not None + edges, had_error = _analyze_file(Path(file_path_str), _worker_jedi_project, _worker_project_root_str) + return file_path_str, file_hash, edges, had_error + + +# --------------------------------------------------------------------------- + + +class ReferenceGraph: + SCHEMA_VERSION = 2 + + def __init__(self, project_root: Path, language: str = "python", db_path: Path | None = None) -> None: + import jedi + + self.project_root = project_root.resolve() + self.project_root_str = str(self.project_root) + self.language = language + self.jedi_project = jedi.Project(path=self.project_root) + + if db_path is None: + from codeflash.code_utils.compat import codeflash_cache_db + + db_path = codeflash_cache_db + + self.conn = sqlite3.connect(str(db_path)) + self.conn.execute("PRAGMA journal_mode=WAL") + self.indexed_file_hashes: dict[str, str] = {} + self._init_schema() + + def _init_schema(self) -> None: + cur = self.conn.cursor() + cur.execute("CREATE TABLE IF NOT EXISTS cg_schema_version (version INTEGER PRIMARY KEY)") + + row = cur.execute("SELECT version FROM cg_schema_version LIMIT 1").fetchone() + if row is None: + cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,)) + elif row[0] != self.SCHEMA_VERSION: + for table in [ + "cg_call_edges", + "cg_indexed_files", + "cg_languages", + "cg_projects", + 
"cg_project_meta", + "indexed_files", + "call_edges", + ]: + cur.execute(f"DROP TABLE IF EXISTS {table}") + cur.execute("DELETE FROM cg_schema_version") + cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,)) + + cur.execute( + """ + CREATE TABLE IF NOT EXISTS indexed_files ( + project_root TEXT NOT NULL, + language TEXT NOT NULL, + file_path TEXT NOT NULL, + file_hash TEXT NOT NULL, + PRIMARY KEY (project_root, language, file_path) + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS call_edges ( + project_root TEXT NOT NULL, + language TEXT NOT NULL, + caller_file TEXT NOT NULL, + caller_qualified_name TEXT NOT NULL, + callee_file TEXT NOT NULL, + callee_qualified_name TEXT NOT NULL, + callee_fully_qualified_name TEXT NOT NULL, + callee_only_function_name TEXT NOT NULL, + callee_definition_type TEXT NOT NULL, + callee_source_line TEXT NOT NULL, + PRIMARY KEY (project_root, language, caller_file, caller_qualified_name, + callee_file, callee_qualified_name) + ) + """ + ) + cur.execute( + """ + CREATE INDEX IF NOT EXISTS idx_call_edges_caller + ON call_edges (project_root, language, caller_file, caller_qualified_name) + """ + ) + self.conn.commit() + + def get_callees( + self, file_path_to_qualified_names: dict[Path, set[str]] + ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: + file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set) + function_source_list: list[FunctionSource] = [] + + all_caller_keys: list[tuple[str, str]] = [] + for file_path, qualified_names in file_path_to_qualified_names.items(): + resolved = str(file_path.resolve()) + self.ensure_file_indexed(file_path, resolved) + all_caller_keys.extend((resolved, qn) for qn in qualified_names) + + if not all_caller_keys: + return file_path_to_function_source, function_source_list + + cur = self.conn.cursor() + cur.execute("CREATE TEMP TABLE IF NOT EXISTS _caller_keys (caller_file TEXT, caller_qualified_name TEXT)") + 
cur.execute("DELETE FROM _caller_keys") + cur.executemany("INSERT INTO _caller_keys VALUES (?, ?)", all_caller_keys) + + rows = cur.execute( + """ + SELECT ce.callee_file, ce.callee_qualified_name, ce.callee_fully_qualified_name, + ce.callee_only_function_name, ce.callee_definition_type, ce.callee_source_line + FROM call_edges ce + INNER JOIN _caller_keys ck + ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name + WHERE ce.project_root = ? AND ce.language = ? + """, + (self.project_root_str, self.language), + ).fetchall() + + for callee_file, callee_qn, callee_fqn, callee_name, callee_type, callee_src in rows: + callee_path = Path(callee_file) + fs = FunctionSource( + file_path=callee_path, + qualified_name=callee_qn, + fully_qualified_name=callee_fqn, + only_function_name=callee_name, + source_code=callee_src, + definition_type=callee_type, + ) + file_path_to_function_source[callee_path].add(fs) + function_source_list.append(fs) + + return file_path_to_function_source, function_source_list + + def count_callees_per_function( + self, file_path_to_qualified_names: dict[Path, set[str]] + ) -> dict[tuple[Path, str], int]: + all_caller_keys: list[tuple[Path, str, str]] = [] + for file_path, qualified_names in file_path_to_qualified_names.items(): + resolved = str(file_path.resolve()) + self.ensure_file_indexed(file_path, resolved) + all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names) + + if not all_caller_keys: + return {} + + cur = self.conn.cursor() + cur.execute("CREATE TEMP TABLE IF NOT EXISTS _count_keys (caller_file TEXT, caller_qualified_name TEXT)") + cur.execute("DELETE FROM _count_keys") + cur.executemany( + "INSERT INTO _count_keys VALUES (?, ?)", [(resolved, qn) for _, resolved, qn in all_caller_keys] + ) + + rows = cur.execute( + """ + SELECT ck.caller_file, ck.caller_qualified_name, COUNT(ce.rowid) + FROM _count_keys ck + LEFT JOIN call_edges ce + ON ce.caller_file = ck.caller_file AND 
ce.caller_qualified_name = ck.caller_qualified_name + AND ce.project_root = ? AND ce.language = ? + GROUP BY ck.caller_file, ck.caller_qualified_name + """, + (self.project_root_str, self.language), + ).fetchall() + + resolved_to_path: dict[str, Path] = {resolved: fp for fp, resolved, _ in all_caller_keys} + counts: dict[tuple[Path, str], int] = {} + for caller_file, caller_qn, cnt in rows: + counts[(resolved_to_path[caller_file], caller_qn)] = cnt + + return counts + + def ensure_file_indexed(self, file_path: Path, resolved: str | None = None) -> IndexResult: + if resolved is None: + resolved = str(file_path.resolve()) + + # Always read and hash the file before checking the cache so we detect on-disk changes + try: + content = file_path.read_text(encoding="utf-8") + except Exception: + return IndexResult(file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True) + + file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + + if self._is_file_cached(resolved, file_hash): + return IndexResult(file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False) + + return self.index_file(file_path, file_hash, resolved) + + def index_file(self, file_path: Path, file_hash: str, resolved: str | None = None) -> IndexResult: + if resolved is None: + resolved = str(file_path.resolve()) + edges, had_error = _analyze_file(file_path, self.jedi_project, self.project_root_str) + if had_error: + logger.debug(f"ReferenceGraph: failed to parse {file_path}") + return self._persist_edges(file_path, resolved, file_hash, edges, had_error) + + def _persist_edges( + self, file_path: Path, resolved: str, file_hash: str, edges: set[tuple[str, ...]], had_error: bool + ) -> IndexResult: + cur = self.conn.cursor() + scope = (self.project_root_str, self.language) + + # Clear existing data for this file + cur.execute( + "DELETE FROM call_edges WHERE project_root = ? AND language = ? 
AND caller_file = ?", (*scope, resolved) + ) + cur.execute( + "DELETE FROM indexed_files WHERE project_root = ? AND language = ? AND file_path = ?", (*scope, resolved) + ) + + # Insert new edges if parsing succeeded + if not had_error and edges: + cur.executemany( + """ + INSERT OR REPLACE INTO call_edges + (project_root, language, caller_file, caller_qualified_name, + callee_file, callee_qualified_name, callee_fully_qualified_name, + callee_only_function_name, callee_definition_type, callee_source_line) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [(*scope, *edge) for edge in edges], + ) + + # Record that this file has been indexed + cur.execute( + "INSERT OR REPLACE INTO indexed_files (project_root, language, file_path, file_hash) VALUES (?, ?, ?, ?)", + (*scope, resolved, file_hash), + ) + + self.conn.commit() + self.indexed_file_hashes[resolved] = file_hash + + # Build summary for return value + edges_summary = tuple( + (caller_qn, callee_name, caller_file != callee_file) + for (caller_file, caller_qn, callee_file, _, _, callee_name, _, _) in edges + ) + cross_file_count = sum(is_cross_file for _, _, is_cross_file in edges_summary) + + return IndexResult( + file_path=file_path, + cached=False, + num_edges=len(edges), + edges=edges_summary, + cross_file_edges=cross_file_count, + error=had_error, + ) + + def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexResult], None] | None = None) -> None: + """Pre-index a batch of files, using multiprocessing for large uncached batches.""" + to_index: list[tuple[Path, str, str]] = [] + + for file_path in file_paths: + resolved = str(file_path.resolve()) + + try: + content = file_path.read_text(encoding="utf-8") + except Exception: + self._report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True + ), + ) + continue + + file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + + # Check if already cached 
(in-memory or DB) + if self._is_file_cached(resolved, file_hash): + self._report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False + ), + ) + continue + + to_index.append((file_path, resolved, file_hash)) + + if not to_index: + return + + # Index uncached files + if len(to_index) >= _PARALLEL_THRESHOLD: + self._build_index_parallel(to_index, on_progress) + else: + for file_path, resolved, file_hash in to_index: + result = self.index_file(file_path, file_hash, resolved) + self._report_progress(on_progress, result) + + def _is_file_cached(self, resolved: str, file_hash: str) -> bool: + """Check if file is cached in memory or DB.""" + if self.indexed_file_hashes.get(resolved) == file_hash: + return True + + row = self.conn.execute( + "SELECT file_hash FROM indexed_files WHERE project_root = ? AND language = ? AND file_path = ?", + (self.project_root_str, self.language, resolved), + ).fetchone() + + if row and row[0] == file_hash: + self.indexed_file_hashes[resolved] = file_hash + return True + + return False + + def _report_progress(self, on_progress: Callable[[IndexResult], None] | None, result: IndexResult) -> None: + """Report progress if callback provided.""" + if on_progress is not None: + on_progress(result) + + def _build_index_parallel( + self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None + ) -> None: + from concurrent.futures import ProcessPoolExecutor, as_completed + + max_workers = min(os.cpu_count() or 1, len(to_index), 8) + path_info: dict[str, tuple[Path, str]] = {resolved: (fp, fh) for fp, resolved, fh in to_index} + worker_args = [(resolved, fh) for _fp, resolved, fh in to_index] + + logger.debug(f"ReferenceGraph: indexing {len(to_index)} files across {max_workers} workers") + + try: + with ProcessPoolExecutor( + max_workers=max_workers, initializer=_init_index_worker, initargs=(self.project_root_str,) + ) as executor: + futures 
= {executor.submit(_index_file_worker, args): args[0] for args in worker_args} + + for future in as_completed(futures): + resolved = futures[future] + file_path, file_hash = path_info[resolved] + + try: + _, _, edges, had_error = future.result() + except Exception: + logger.debug(f"ReferenceGraph: worker failed for {file_path}") + self._persist_edges(file_path, resolved, file_hash, set(), had_error=True) + self._report_progress( + on_progress, + IndexResult( + file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True + ), + ) + continue + + if had_error: + logger.debug(f"ReferenceGraph: failed to parse {file_path}") + + result = self._persist_edges(file_path, resolved, file_hash, edges, had_error) + self._report_progress(on_progress, result) + + except Exception: + logger.debug("ReferenceGraph: parallel indexing failed, falling back to sequential") + self._fallback_sequential_index(to_index, on_progress) + + def _fallback_sequential_index( + self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None + ) -> None: + """Fallback to sequential indexing when parallel processing fails.""" + for file_path, resolved, file_hash in to_index: + # Skip files already persisted before the failure + if resolved in self.indexed_file_hashes: + continue + result = self.index_file(file_path, file_hash, resolved) + self._report_progress(on_progress, result) + + def close(self) -> None: + self.conn.close() diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index ddff51b8f..b026e99e5 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from collections.abc import Sequence + from codeflash.languages.base import DependencyResolver from codeflash.models.models import FunctionSource logger = logging.getLogger(__name__) @@ -751,6 +752,15 @@ def ensure_runtime_environment(self, project_root: Path) -> bool: """ 
return True + def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None: + from codeflash.languages.python.reference_graph import ReferenceGraph + + try: + return ReferenceGraph(project_root, language=self.language.value) + except Exception: + logger.debug("Failed to initialize ReferenceGraph, falling back to per-function Jedi analysis") + return None + def instrument_existing_test( self, test_path: Path, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index b7aeb43b1..697601403 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -25,7 +25,6 @@ from re import Pattern from typing import Any, NamedTuple, Optional, cast -from jedi.api.classes import Name from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator from pydantic.dataclasses import dataclass @@ -136,14 +135,14 @@ class CoverReturnCode(IntEnum): ERROR = 2 -@dataclass(frozen=True, config={"arbitrary_types_allowed": True}) +@dataclass(frozen=True) class FunctionSource: file_path: Path qualified_name: str fully_qualified_name: str only_function_name: str source_code: str - jedi_definition: Name | None = None # None for non-Python languages + definition_type: str | None = None # e.g. 
"function", "class"; None for non-Python languages def __eq__(self, other: object) -> bool: if not isinstance(other, FunctionSource): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index c24d84ae5..55dfa314e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -140,6 +140,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result + from codeflash.languages.base import DependencyResolver from codeflash.models.models import ( BenchmarkKey, CodeStringsMarkdown, @@ -443,6 +444,7 @@ def __init__( total_benchmark_timings: dict[BenchmarkKey, int] | None = None, args: Namespace | None = None, replay_tests_dir: Path | None = None, + call_graph: DependencyResolver | None = None, ) -> None: self.project_root = test_cfg.project_root_path.resolve() self.test_cfg = test_cfg @@ -488,6 +490,7 @@ def __init__( self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None + self.call_graph = call_graph n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.effort) self.executor = concurrent.futures.ThreadPoolExecutor( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 @@ -1492,8 +1495,8 @@ def replace_function_and_helpers_with_optimized_code( self.function_to_optimize.qualified_name ) for helper_function in code_context.helper_functions: - # Skip class definitions (jedi_definition may be None for non-Python languages) - if helper_function.jedi_definition is None or helper_function.jedi_definition.type != "class": + # Skip class definitions (definition_type may be None for non-Python languages) + if helper_function.definition_type != "class": 
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) for module_abspath, qualified_names in read_writable_functions_by_file_path.items(): did_update |= replace_function_definitions_in_module( @@ -1514,7 +1517,7 @@ def replace_function_and_helpers_with_optimized_code( def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: try: new_code_ctx = code_context_extractor.get_code_optimization_context( - self.function_to_optimize, self.project_root + self.function_to_optimize, self.project_root, call_graph=self.call_graph ) except ValueError as e: return Failure(str(e)) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 06540ca85..3211ab59b 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -11,7 +11,13 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import send_completion_email -from codeflash.cli_cmds.console import console, logger, progress_bar +from codeflash.cli_cmds.console import ( # noqa: F401 + call_graph_live_display, + call_graph_summary, + console, + logger, + progress_bar, +) from codeflash.code_utils import env_utils from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft @@ -24,7 +30,7 @@ ) from codeflash.code_utils.time_utils import humanize_runtime from codeflash.either import is_successful -from codeflash.languages import is_javascript, set_current_language +from codeflash.languages import current_language_support, is_javascript, set_current_language from codeflash.models.models import ValidCode from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig @@ -35,6 +41,7 @@ from codeflash.benchmarking.function_ranker import FunctionRanker from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from 
codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.base import DependencyResolver from codeflash.models.models import BenchmarkKey, FunctionCalledInTest from codeflash.optimization.function_optimizer import FunctionOptimizer @@ -241,6 +248,7 @@ def create_function_optimizer( total_benchmark_timings: dict[BenchmarkKey, float] | None = None, original_module_ast: ast.Module | None = None, original_module_path: Path | None = None, + call_graph: DependencyResolver | None = None, ) -> FunctionOptimizer | None: from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.optimization.function_optimizer import FunctionOptimizer @@ -279,6 +287,7 @@ def create_function_optimizer( function_benchmark_timings=function_specific_timings, total_benchmark_timings=total_benchmark_timings if function_specific_timings else None, replay_tests_dir=self.replay_tests_dir, + call_graph=call_graph, ) def prepare_module_for_optimization( @@ -422,7 +431,10 @@ def display_global_ranking( console.print(f"[dim]... and {len(globally_ranked) - display_count} more functions[/dim]") def rank_all_functions_globally( - self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None + self, + file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], + trace_file_path: Path | None, + call_graph: DependencyResolver | None = None, ) -> list[tuple[Path, FunctionToOptimize]]: """Rank all functions globally across all files based on trace data. 
@@ -442,8 +454,10 @@ def rank_all_functions_globally( for file_path, functions in file_to_funcs_to_optimize.items(): all_functions.extend((file_path, func) for func in functions) - # If no trace file, return in original order + # If no trace file, rank by dependency count if call graph is available if not trace_file_path or not trace_file_path.exists(): + if call_graph is not None: + return self.rank_by_dependency_count(all_functions, call_graph) logger.debug("No trace file available, using original function order") return all_functions @@ -494,6 +508,19 @@ def rank_all_functions_globally( else: return globally_ranked + def rank_by_dependency_count( + self, all_functions: list[tuple[Path, FunctionToOptimize]], call_graph: DependencyResolver + ) -> list[tuple[Path, FunctionToOptimize]]: + file_to_qns: dict[Path, set[str]] = defaultdict(set) + for file_path, func in all_functions: + file_to_qns[file_path].add(func.qualified_name) + callee_counts = call_graph.count_callees_per_function(dict(file_to_qns)) + ranked = sorted( + enumerate(all_functions), key=lambda x: (-callee_counts.get((x[1][0], x[1][1].qualified_name), 0), x[0]) + ) + logger.debug(f"Ranked {len(ranked)} functions by dependency count (most complex first)") + return [item for _, item in ranked] + def run(self) -> None: from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint @@ -536,16 +563,33 @@ def run(self) -> None: if self.args.all: three_min_in_ns = int(1.8e11) console.rule() - pr_message = ( - "\nCodeflash will keep opening pull requests as it finds optimizations." if not self.args.no_pr else "" - ) logger.info( - f"It might take about {humanize_runtime(num_optimizable_functions * three_min_in_ns)} to fully optimize this project.{pr_message}" + f"It might take about {humanize_runtime(num_optimizable_functions * three_min_in_ns)} to fully optimize this project." 
) + if not self.args.no_pr: + logger.info("Codeflash will keep opening pull requests as it finds optimizations.") + console.rule() function_benchmark_timings, total_benchmark_timings = self.run_benchmarks( file_to_funcs_to_optimize, num_optimizable_functions ) + + # Create a language-specific dependency resolver (e.g. Jedi-based call graph for Python) + # Skip in CI — the cache DB doesn't persist between runs on ephemeral runners + lang_support = current_language_support() + resolver = None + # CURRENTLY DISABLED: The resolver is currently not used for anything until i clean up the repo structure for python + # if lang_support and not env_utils.is_ci(): + # resolver = lang_support.create_dependency_resolver(self.args.project_root) + + # if resolver is not None and lang_support is not None and file_to_funcs_to_optimize: + # supported_exts = lang_support.file_extensions + # source_files = [f for f in file_to_funcs_to_optimize if f.suffix in supported_exts] + # with call_graph_live_display(len(source_files), project_root=self.args.project_root) as on_progress: + # resolver.build_index(source_files, on_progress=on_progress) + # console.rule() + # call_graph_summary(resolver, file_to_funcs_to_optimize) + optimizations_found: int = 0 self.test_cfg.concolic_test_root_dir = Path( tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_") @@ -561,7 +605,9 @@ def run(self) -> None: self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) # GLOBAL RANKING: Rank all functions together before optimizing - globally_ranked_functions = self.rank_all_functions_globally(file_to_funcs_to_optimize, trace_file_path) + globally_ranked_functions = self.rank_all_functions_globally( + file_to_funcs_to_optimize, trace_file_path, call_graph=resolver + ) # Cache for module preparation (avoid re-parsing same files) prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module | None]] = {} @@ -593,6 +639,7 @@ def run(self) -> None: 
total_benchmark_timings=total_benchmark_timings, original_module_ast=original_module_ast, original_module_path=original_module_path, + call_graph=resolver, ) if function_optimizer is None: continue @@ -651,6 +698,9 @@ def run(self) -> None: else: logger.warning("⚠️ Failed to send completion email. Status") finally: + if resolver is not None: + resolver.close() + if function_optimizer: function_optimizer.cleanup_generated_files() diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 84da13762..4dfddb4f7 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -767,11 +767,204 @@ def helper_method(self): assert hashing_context.strip() == expected_hashing_context.strip() -def test_example_class_token_limit_3(tmp_path: Path) -> None: +def test_example_class_token_limit_1(tmp_path: Path) -> None: + docstring_filler = " ".join( + ["This is a long docstring that will be used to fill up the token limit." for _ in range(4000)] + ) + code = f""" +class MyClass: + \"\"\"A class with a helper method. 
+{docstring_filler}\"\"\" + def __init__(self): + self.x = 1 + def target_method(self): + \"\"\"Docstring for target method\"\"\" + y = HelperClass().helper_method() + +class HelperClass: + \"\"\"A helper class for MyClass.\"\"\" + def __init__(self): + \"\"\"Initialize the HelperClass.\"\"\" + self.x = 1 + def __repr__(self): + \"\"\"Return a string representation of the HelperClass.\"\"\" + return "HelperClass" + str(self.x) + def helper_method(self): + return self.x +""" + # Create a temporary Python file using pytest's tmp_path fixture + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. 
+ expected_read_write_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def __init__(self): + self.x = 1 + def target_method(self): + \"\"\"Docstring for target method\"\"\" + y = HelperClass().helper_method() + +class HelperClass: + def __init__(self): + \"\"\"Initialize the HelperClass.\"\"\" + self.x = 1 + def helper_method(self): + return self.x +``` +""" + expected_read_only_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + pass + +class HelperClass: + def __repr__(self): + return "HelperClass" + str(self.x) +``` +""" + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + + def helper_method(self): + return self.x +``` +""" + assert read_write_context.markdown.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + + +def test_example_class_token_limit_2(tmp_path: Path) -> None: string_filler = " ".join( ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] ) code = f""" +class MyClass: + \"\"\"A class with a helper method. 
\"\"\" + def __init__(self): + self.x = 1 + def target_method(self): + \"\"\"Docstring for target method\"\"\" + y = HelperClass().helper_method() +x = '{string_filler}' + +class HelperClass: + \"\"\"A helper class for MyClass.\"\"\" + def __init__(self): + \"\"\"Initialize the HelperClass.\"\"\" + self.x = 1 + def __repr__(self): + \"\"\"Return a string representation of the HelperClass.\"\"\" + return "HelperClass" + str(self.x) + def helper_method(self): + return self.x +""" + # Create a temporary Python file using pytest's tmp_path fixture + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + hashing_context = code_ctx.hashing_code_context + # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. 
+ expected_read_write_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + def __init__(self): + self.x = 1 + def target_method(self): + \"\"\"Docstring for target method\"\"\" + y = HelperClass().helper_method() + +class HelperClass: + def __init__(self): + \"\"\"Initialize the HelperClass.\"\"\" + self.x = 1 + def helper_method(self): + return self.x +``` +""" + expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + """A class with a helper method. """ + +class HelperClass: + """A helper class for MyClass.""" + def __repr__(self): + """Return a string representation of the HelperClass.""" + return "HelperClass" + str(self.x) +``` +''' + expected_hashing_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + + def target_method(self): + y = HelperClass().helper_method() + +class HelperClass: + + def helper_method(self): + return self.x +``` +""" + assert read_write_context.markdown.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + assert hashing_context.strip() == expected_hashing_context.strip() + + +def test_example_class_token_limit_3(tmp_path: Path) -> None: + string_filler = " ".join( + ["This is a long string that will be used to fill up the token limit." for _ in range(4000)] + ) + code = f""" class MyClass: \"\"\"A class with a helper method. \"\"\" def __init__(self): @@ -819,7 +1012,7 @@ def helper_method(self): def test_example_class_token_limit_4(tmp_path: Path) -> None: string_filler = " ".join( - ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] + ["This is a long string that will be used to fill up the token limit." 
for _ in range(4000)] ) code = f""" class MyClass: diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index eccdc4e03..fbca6d71e 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import os import re from collections import defaultdict @@ -19,28 +18,13 @@ replace_functions_in_file, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent +from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig os.environ["CODEFLASH_API_KEY"] = "cf-test-key" -@dataclasses.dataclass -class JediDefinition: - type: str - - -@dataclasses.dataclass -class FakeFunctionSource: - file_path: Path - qualified_name: str - fully_qualified_name: str - only_function_name: str - source_code: str - jedi_definition: JediDefinition - - class Args: disable_imports_sorting = True formatter_cmds = ["disabled"] @@ -1137,7 +1121,7 @@ def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: preexisting_objects = find_preexisting_objects(original_code) helper_functions = [ - FakeFunctionSource( + FunctionSource( file_path=Path( "/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py" ), @@ -1145,7 +1129,7 @@ def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: fully_qualified_name="codeflash.verification.test_results.TestType", only_function_name="TestType", source_code="", - jedi_definition=JediDefinition(type="class"), + definition_type="class", ) ] @@ -1160,7 +1144,7 @@ def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: helper_functions_by_module_abspath = 
defaultdict(set) for helper_function in helper_functions: - if helper_function.jedi_definition.type != "class": + if helper_function.definition_type != "class": helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name) for module_abspath, qualified_names in helper_functions_by_module_abspath.items(): new_code: str = replace_functions_and_add_imports( @@ -1352,21 +1336,21 @@ def cosine_similarity_top_k( preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) helper_functions = [ - FakeFunctionSource( + FunctionSource( file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(), qualified_name="Matrix", fully_qualified_name="code_to_optimize.math_utils.Matrix", only_function_name="Matrix", source_code="", - jedi_definition=JediDefinition(type="class"), + definition_type="class", ), - FakeFunctionSource( + FunctionSource( file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(), qualified_name="cosine_similarity", fully_qualified_name="code_to_optimize.math_utils.cosine_similarity", only_function_name="cosine_similarity", source_code="", - jedi_definition=JediDefinition(type="function"), + definition_type="function", ), ] @@ -1425,7 +1409,7 @@ def cosine_similarity_top_k( ) helper_functions_by_module_abspath = defaultdict(set) for helper_function in helper_functions: - if helper_function.jedi_definition.type != "class": + if helper_function.definition_type != "class": helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name) for module_abspath, qualified_names in helper_functions_by_module_abspath.items(): new_helper_code: str = replace_functions_and_add_imports( diff --git a/tests/test_function_dependencies.py b/tests/test_function_dependencies.py index f51780f92..988f60b7b 100644 --- a/tests/test_function_dependencies.py +++ b/tests/test_function_dependencies.py @@ -151,10 +151,9 @@ def 
test_class_method_dependencies() -> None: # The code_context above should have the topologicalSortUtil function in it assert len(code_context.helper_functions) == 1 assert ( - code_context.helper_functions[0].jedi_definition.full_name - == "test_function_dependencies.Graph.topologicalSortUtil" + code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil" ) - assert code_context.helper_functions[0].jedi_definition.name == "topologicalSortUtil" + assert code_context.helper_functions[0].only_function_name == "topologicalSortUtil" assert ( code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil" ) diff --git a/tests/test_reference_graph.py b/tests/test_reference_graph.py new file mode 100644 index 000000000..6e3ab0c65 --- /dev/null +++ b/tests/test_reference_graph.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from pathlib import Path + +from codeflash.languages.base import IndexResult +from codeflash.languages.python.reference_graph import ReferenceGraph + + +@pytest.fixture +def project(tmp_path: Path) -> Path: + project_root = tmp_path / "project" + project_root.mkdir() + return project_root + + +@pytest.fixture +def db_path(tmp_path: Path) -> Path: + return tmp_path / "cache.db" + + +def write_file(project: Path, name: str, content: str) -> Path: + fp = project / name + fp.write_text(content, encoding="utf-8") + return fp + + +# --------------------------------------------------------------------------- +# Unit tests +# --------------------------------------------------------------------------- + + +def test_simple_function_call(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project 
/ "mod.py": {"caller"}}) + callee_qns = {fs.qualified_name for fs in result_list} + assert "helper" in callee_qns + finally: + cg.close() + + +def test_cross_file_call(project: Path, db_path: Path) -> None: + write_file( + project, + "utils.py", + """\ +def utility(): + return 42 +""", + ) + write_file( + project, + "main.py", + """\ +from utils import utility + +def caller(): + return utility() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project / "main.py": {"caller"}}) + callee_qns = {fs.qualified_name for fs in result_list} + assert "utility" in callee_qns + # Should be in the utils.py file + callee_files = {fs.file_path.resolve() for fs in result_list if fs.qualified_name == "utility"} + assert (project / "utils.py").resolve() in callee_files + finally: + cg.close() + + +def test_class_instantiation(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +class MyClass: + def __init__(self): + pass + +def caller(): + obj = MyClass() + return obj +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project / "mod.py": {"caller"}}) + callee_types = {fs.definition_type for fs in result_list} + assert "class" in callee_types + finally: + cg.close() + + +def test_nested_function_excluded(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def caller(): + def inner(): + return 1 + return inner() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project / "mod.py": {"caller"}}) + assert len(result_list) == 0 + finally: + cg.close() + + +def test_module_level_not_tracked(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +x = helper() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + # Module level calls have no enclosing function, so no edges + _, result_list = 
cg.get_callees({project / "mod.py": {"helper"}}) + # helper itself doesn't call anything + assert len(result_list) == 0 + finally: + cg.close() + + +def test_site_packages_excluded(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +import os + +def caller(): + return os.path.join("a", "b") +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project / "mod.py": {"caller"}}) + # os.path.join is stdlib, should not appear + assert len(result_list) == 0 + finally: + cg.close() + + +def test_empty_file(project: Path, db_path: Path) -> None: + write_file(project, "mod.py", "") + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project / "mod.py": set()}) + assert len(result_list) == 0 + finally: + cg.close() + + +def test_syntax_error_file(project: Path, db_path: Path) -> None: + write_file(project, "mod.py", "def broken(\n") + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project / "mod.py": {"broken"}}) + assert len(result_list) == 0 + finally: + cg.close() + + +# --------------------------------------------------------------------------- +# Caching tests +# --------------------------------------------------------------------------- + + +def test_caching_no_reindex(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + cg.get_callees({project / "mod.py": {"caller"}}) + # Second call should use in-memory cache (hash unchanged) + resolved = str((project / "mod.py").resolve()) + assert resolved in cg.indexed_file_hashes + old_hash = cg.indexed_file_hashes[resolved] + cg.get_callees({project / "mod.py": {"caller"}}) + assert cg.indexed_file_hashes[resolved] == old_hash + finally: + cg.close() + + +def test_incremental_update_on_change(project: Path, db_path: 
Path) -> None: + fp = write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg.get_callees({project / "mod.py": {"caller"}}) + assert any(fs.qualified_name == "helper" for fs in result_list) + + # Modify the file — caller no longer calls helper + fp.write_text( + """\ +def helper(): + return 1 + +def new_helper(): + return 2 + +def caller(): + return new_helper() +""", + encoding="utf-8", + ) + _, result_list = cg.get_callees({project / "mod.py": {"caller"}}) + callee_qns = {fs.qualified_name for fs in result_list} + assert "new_helper" in callee_qns + finally: + cg.close() + + +def test_persistence_across_sessions(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + # First session: index the file + cg1 = ReferenceGraph(project, db_path=db_path) + try: + _, result_list = cg1.get_callees({project / "mod.py": {"caller"}}) + assert any(fs.qualified_name == "helper" for fs in result_list) + finally: + cg1.close() + + # Second session: should read from DB without re-indexing + cg2 = ReferenceGraph(project, db_path=db_path) + try: + assert len(cg2.indexed_file_hashes) == 0 # in-memory cache is empty + _, result_list = cg2.get_callees({project / "mod.py": {"caller"}}) + assert any(fs.qualified_name == "helper" for fs in result_list) + finally: + cg2.close() + + +def test_build_index_with_progress(project: Path, db_path: Path) -> None: + write_file( + project, + "a.py", + """\ +def helper_a(): + return 1 + +def caller_a(): + return helper_a() +""", + ) + write_file( + project, + "b.py", + """\ +from a import helper_a + +def caller_b(): + return helper_a() +""", + ) + + cg = ReferenceGraph(project, db_path=db_path) + try: + progress_calls: list[IndexResult] = [] + files = [project / "a.py", project / "b.py"] + cg.build_index(files, 
on_progress=progress_calls.append) + + # Callback fired once per file + assert len(progress_calls) == 2 + + # Verify IndexResult fields for freshly indexed files + for result in progress_calls: + assert isinstance(result, IndexResult) + assert not result.error + assert not result.cached + assert result.num_edges > 0 + assert len(result.edges) == result.num_edges + assert result.cross_file_edges >= 0 + + # Files are now indexed — get_callees should return correct results + _, result_list = cg.get_callees({project / "a.py": {"caller_a"}}) + callee_qns = {fs.qualified_name for fs in result_list} + assert "helper_a" in callee_qns + finally: + cg.close() + + +def test_build_index_cached_results(project: Path, db_path: Path) -> None: + write_file( + project, + "a.py", + """\ +def helper_a(): + return 1 + +def caller_a(): + return helper_a() +""", + ) + write_file( + project, + "b.py", + """\ +from a import helper_a + +def caller_b(): + return helper_a() +""", + ) + + cg = ReferenceGraph(project, db_path=db_path) + try: + files = [project / "a.py", project / "b.py"] + # First pass — fresh indexing + cg.build_index(files) + + # Second pass — should all be cached + cached_results: list[IndexResult] = [] + cg.build_index(files, on_progress=cached_results.append) + + assert len(cached_results) == 2 + for result in cached_results: + assert result.cached + assert not result.error + assert result.num_edges == 0 + assert result.edges == () + assert result.cross_file_edges == 0 + finally: + cg.close() + + +def test_cross_file_edges_tracked(project: Path, db_path: Path) -> None: + write_file( + project, + "utils.py", + """\ +def utility(): + return 42 +""", + ) + write_file( + project, + "main.py", + """\ +from utils import utility + +def caller(): + return utility() +""", + ) + + cg = ReferenceGraph(project, db_path=db_path) + try: + progress_calls: list[IndexResult] = [] + cg.build_index([project / "utils.py", project / "main.py"], on_progress=progress_calls.append) + + # main.py 
should have cross-file edges (calls into utils.py) + main_result = next(r for r in progress_calls if r.file_path.name == "main.py") + assert main_result.cross_file_edges > 0 + # At least one edge tuple should have is_cross_file=True + assert any(is_cross_file for _, _, is_cross_file in main_result.edges) + finally: + cg.close() + + +def test_count_callees_per_function(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper_a(): + return 1 + +def helper_b(): + return 2 + +def caller_one(): + return helper_a() + helper_b() + +def caller_two(): + return helper_a() + +def leaf(): + return 42 +""", + ) + + cg = ReferenceGraph(project, db_path=db_path) + try: + cg.build_index([project / "mod.py"]) + mod_path = project / "mod.py" + counts = cg.count_callees_per_function({mod_path: {"caller_one", "caller_two", "leaf"}}) + assert counts[(mod_path, "caller_one")] == 2 + assert counts[(mod_path, "caller_two")] == 1 + assert counts[(mod_path, "leaf")] == 0 + finally: + cg.close() + + +def test_same_file_edges_not_cross_file(project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +def helper(): + return 1 + +def caller(): + return helper() +""", + ) + + cg = ReferenceGraph(project, db_path=db_path) + try: + progress_calls: list[IndexResult] = [] + cg.build_index([project / "mod.py"], on_progress=progress_calls.append) + + assert len(progress_calls) == 1 + result = progress_calls[0] + assert result.cross_file_edges == 0 + # All edges should have is_cross_file=False + assert all(not is_cross_file for _, _, is_cross_file in result.edges) + finally: + cg.close() diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index bfc75642c..2a4efae3d 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -918,7 +918,7 @@ def local_helper(self, x): "only_function_name": "global_helper_1", "fully_qualified_name": "main.global_helper_1", "file_path": main_file, 
- "jedi_definition": type("MockJedi", (), {"type": "function"})(), + "definition_type": "function", }, )(), type( @@ -929,7 +929,7 @@ def local_helper(self, x): "only_function_name": "global_helper_2", "fully_qualified_name": "main.global_helper_2", "file_path": main_file, - "jedi_definition": type("MockJedi", (), {"type": "function"})(), + "definition_type": "function", }, )(), ]