diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 69c8bbef2..3eac31934 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -3,7 +3,7 @@ import ast import hashlib import os -from collections import defaultdict +from collections import defaultdict, deque from itertools import chain from typing import TYPE_CHECKING @@ -746,10 +746,15 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]: def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None: class_node = None - for node in ast.walk(module_tree): - if isinstance(node, ast.ClassDef) and node.name == class_name: - class_node = node + # Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order) + q: deque[ast.AST] = deque([module_tree]) + while q: + candidate = q.popleft() + if isinstance(candidate, ast.ClassDef) and candidate.name == class_name: + class_node = candidate break + q.extend(ast.iter_child_nodes(candidate)) + if class_node is None: return None @@ -757,22 +762,33 @@ def extract_init_stub_from_class(class_name: str, module_source: str, module_tre relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = [] for item in class_node.body: if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): - if item.name in ("__init__", "__post_init__") or any( - (isinstance(d, ast.Name) and d.id == "property") - or (isinstance(d, ast.Attribute) and d.attr == "property") - for d in item.decorator_list - ): + is_relevant = False + if item.name in ("__init__", "__post_init__"): + is_relevant = True + else: + # Check decorators explicitly to avoid generator overhead + for d in item.decorator_list: + if (isinstance(d, ast.Name) and d.id == "property") or ( + isinstance(d, ast.Attribute) and d.attr == "property" + ): + is_relevant = True + break + if is_relevant: relevant_nodes.append(item) if not relevant_nodes: return None snippets: list[str] = [] - for node in relevant_nodes: - start = node.lineno - if node.decorator_list: - start = min(d.lineno for d in node.decorator_list) - snippets.append("\n".join(lines[start - 1 : node.end_lineno])) + for fn_node in relevant_nodes: + start = fn_node.lineno + if fn_node.decorator_list: + # Compute minimum decorator lineno with an explicit loop (avoids generator/min overhead) + m = start + for d in fn_node.decorator_list: + m = min(m, d.lineno) + start = m + snippets.append("\n".join(lines[start - 1 : fn_node.end_lineno])) return f"class {class_name}:\n" + "\n".join(snippets)