From e05ff4dea52349ce41242d00310848eaaba7948d Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 04:23:45 +0000 Subject: [PATCH] Optimize _add_java_class_members The optimized code achieves a **47% runtime improvement** (from 47.3ms to 32.0ms) by eliminating redundant parsing operations through two key optimizations: ## What Changed ### 1. Parse Result Caching in JavaAnalyzer Added a `_tree_cache` dictionary and `_get_cached_tree()` method that caches parsed tree-sitter trees by source content. This prevents reparsing identical source code when `find_methods()`, `find_classes()`, or `find_fields()` are called multiple times on the same source within a single analyzer instance. ### 2. Elimination of Repeated Reparsing in `_insert_class_members()` The original implementation reparsed the source after inserting fields to get updated byte positions for method insertion. The optimized version: - Computes both insertion points (fields and methods) from the **original** parse tree - Tracks a `delta` offset as bytes are inserted - Adjusts subsequent insertion points by the accumulated delta - Uses list-based string building (`field_parts`, `method_parts`) instead of repeated concatenation - Works entirely in bytes until the final decode, avoiding multiple encode/decode cycles ## Why It's Faster **Parse elimination is the key win**: Line profiler shows parsing (`self.parse()`) consumed 10.8-28.2% of time in various methods in the original. By caching parse results, the optimized code avoids ~60% of these expensive tree-sitter operations when the same source is analyzed multiple times (e.g., `_add_java_class_members` calls `find_classes()` on `original_source` twice and `optimized_code` up to three times). 
**Delta-based insertion tracking**: The original code's repeated reparsing after field insertion (`classes = analyzer.find_classes(result)` taking 14.2% and 28.2% of `_insert_class_members` time) is completely eliminated. The optimized version calculates both insertion points upfront and adjusts them arithmetically. **String operation efficiency**: Using list accumulation for building field/method text reduces repeated string concatenation overhead, though this is a minor contributor compared to parse elimination. ## Test Case Performance The optimization shows speedups ranging from 25% to 108% across test cases: - **Basic operations**: 35-50% faster (e.g., `test_basic_adds_new_static_field_when_missing`: 38% faster) - **Large scale**: 30-108% faster (e.g., `test_large_scale_many_fields_and_methods`: 37.8% faster, `test_performance_with_large_original_source`: 108% faster due to eliminating repeated parsing of large ASTs) - **No performance regressions** on meaningful workloads (edge cases with invalid/empty inputs show negligible 1-5% variations due to measurement noise) ## Impact Context This optimization is valuable when Java source code analysis involves repeated queries on the same source content, particularly in code transformation pipelines where multiple analyses (classes, methods, fields) are performed on both original and optimized versions of the same code. --- codeflash/languages/java/parser.py | 33 +++++++-- codeflash/languages/java/replacement.py | 92 +++++++++++++------------ 2 files changed, 74 insertions(+), 51 deletions(-) diff --git a/codeflash/languages/java/parser.py b/codeflash/languages/java/parser.py index 72a530179..684bb21e3 100644 --- a/codeflash/languages/java/parser.py +++ b/codeflash/languages/java/parser.py @@ -111,6 +111,11 @@ def __init__(self) -> None: """Initialize the Java analyzer.""" self._parser: Parser | None = None + # Small cache mapping source bytes -> parsed Tree to avoid repeated parsing + # for identical source content. 
This helps when the same source is queried + # multiple times by callers (common in the codebase). + self._tree_cache: dict[bytes, Tree] = {} + @property def parser(self) -> Parser: """Get the parser, creating it lazily.""" @@ -159,8 +164,8 @@ def find_methods( List of JavaMethodNode objects describing found methods. """ - source_bytes = source.encode("utf8") - tree = self.parse(source_bytes) + # Use cached parse tree when possible to avoid repeated expensive parse + tree, source_bytes = self._get_cached_tree(source) methods: list[JavaMethodNode] = [] self._walk_tree_for_methods( @@ -314,8 +319,7 @@ def find_classes(self, source: str) -> list[JavaClassNode]: List of JavaClassNode objects. """ - source_bytes = source.encode("utf8") - tree = self.parse(source_bytes) + tree, source_bytes = self._get_cached_tree(source) classes: list[JavaClassNode] = [] self._walk_tree_for_classes(tree.root_node, source_bytes, classes, is_inner=False) @@ -479,8 +483,7 @@ def find_fields(self, source: str, class_name: str | None = None) -> list[JavaFi List of JavaFieldInfo objects. """ - source_bytes = source.encode("utf8") - tree = self.parse(source_bytes) + tree, source_bytes = self._get_cached_tree(source) fields: list[JavaFieldInfo] = [] self._walk_tree_for_fields(tree.root_node, source_bytes, fields, current_class=None, target_class=class_name) @@ -678,6 +681,24 @@ def get_package_name(self, source: str) -> str | None: return None + def _get_cached_tree(self, source: str | bytes) -> tuple[Tree, bytes]: + """Return a cached parse tree for source or parse and cache it. + + This avoids reparsing identical source multiple times while the analyzer + instance lives, which is a major performance win when callers query + classes/fields/methods repeatedly for the same source. 
+ """ + if isinstance(source, str): + source_bytes = source.encode("utf8") + else: + source_bytes = source + tree = self._tree_cache.get(source_bytes) + if tree is None: + tree = self.parse(source_bytes) + # Cache the tree for future calls + self._tree_cache[source_bytes] = tree + return tree, source_bytes + def get_java_analyzer() -> JavaAnalyzer: """Get a JavaAnalyzer instance. diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index 92ddd44e2..d561b097f 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -144,61 +144,63 @@ def _insert_class_members( class_indent = _get_indentation(lines[class_line]) if class_line < len(lines) else "" member_indent = class_indent + " " - result = source + result_bytes = source_bytes + + # Compute original body insertion points once + body_start = body_node.start_byte + body_end = body_node.end_byte + + # We'll keep track of a byte-offset delta as we modify the bytes so subsequent + # insertions can use adjusted positions without reparsing. 
+ delta = 0 + + # Insert fields at the beginning of the class body (after opening brace) # Insert fields at the beginning of the class body (after opening brace) if fields: - # Re-parse to get current positions - classes = analyzer.find_classes(result) - for cls in classes: - if cls.name == class_name: - body_node = cls.node.child_by_field_name("body") - break + insert_point = body_start + 1 + delta # After opening brace + + # Build field text using list to avoid repeated string concatenation + field_parts: list[str] = [] + field_parts.append("\n") + for field in fields: + field_lines = field.strip().splitlines(keepends=True) + indented_field = _apply_indentation(field_lines, member_indent) + field_parts.append(indented_field) + if not indented_field.endswith("\n"): + field_parts.append("\n") + field_text = "".join(field_parts) + field_bytes = field_text.encode("utf8") - if body_node: - result_bytes = result.encode("utf8") - insert_point = body_node.start_byte + 1 # After opening brace + before = result_bytes[:insert_point] + after = result_bytes[insert_point:] + result_bytes = before + field_bytes + after - # Format fields - field_text = "\n" - for field in fields: - field_lines = field.strip().splitlines(keepends=True) - indented_field = _apply_indentation(field_lines, member_indent) - field_text += indented_field - if not indented_field.endswith("\n"): - field_text += "\n" + delta += len(field_bytes) # Adjust for next insertion(s) - before = result_bytes[:insert_point] - after = result_bytes[insert_point:] - result = (before + field_text.encode("utf8") + after).decode("utf8") + # Insert methods at the end of the class body (before closing brace) # Insert methods at the end of the class body (before closing brace) if methods: - # Re-parse to get current positions - classes = analyzer.find_classes(result) - for cls in classes: - if cls.name == class_name: - body_node = cls.node.child_by_field_name("body") - break - - if body_node: - result_bytes = 
result.encode("utf8") - insert_point = body_node.end_byte - 1 # Before closing brace + insert_point = body_end - 1 + delta # Before closing brace, adjust by delta - # Format methods - method_text = "\n" - for method in methods: - method_lines = method.strip().splitlines(keepends=True) - indented_method = _apply_indentation(method_lines, member_indent) - method_text += indented_method - if not indented_method.endswith("\n"): - method_text += "\n" - - before = result_bytes[:insert_point] - after = result_bytes[insert_point:] - result = (before + method_text.encode("utf8") + after).decode("utf8") - - return result + # Build method text efficiently + method_parts: list[str] = [] + method_parts.append("\n") + for method in methods: + method_lines = method.strip().splitlines(keepends=True) + indented_method = _apply_indentation(method_lines, member_indent) + method_parts.append(indented_method) + if not indented_method.endswith("\n"): + method_parts.append("\n") + method_text = "".join(method_parts) + method_bytes = method_text.encode("utf8") + + before = result_bytes[:insert_point] + after = result_bytes[insert_point:] + result_bytes = before + method_bytes + after + + return result_bytes.decode("utf8") def replace_function(