Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions codeflash/languages/java/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def __init__(self) -> None:
"""Initialize the Java analyzer."""
self._parser: Parser | None = None

# Small cache mapping source bytes -> parsed Tree to avoid repeated parsing
# for identical source content. This helps when the same source is queried
# multiple times by callers (common in the codebase).
self._tree_cache: dict[bytes, Tree] = {}

@property
def parser(self) -> Parser:
"""Get the parser, creating it lazily."""
Expand Down Expand Up @@ -159,8 +164,8 @@ def find_methods(
List of JavaMethodNode objects describing found methods.

"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
# Use cached parse tree when possible to avoid repeated expensive parse
tree, source_bytes = self._get_cached_tree(source)
methods: list[JavaMethodNode] = []

self._walk_tree_for_methods(
Expand Down Expand Up @@ -314,8 +319,7 @@ def find_classes(self, source: str) -> list[JavaClassNode]:
List of JavaClassNode objects.

"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
tree, source_bytes = self._get_cached_tree(source)
classes: list[JavaClassNode] = []

self._walk_tree_for_classes(tree.root_node, source_bytes, classes, is_inner=False)
Expand Down Expand Up @@ -479,8 +483,7 @@ def find_fields(self, source: str, class_name: str | None = None) -> list[JavaFi
List of JavaFieldInfo objects.

"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
tree, source_bytes = self._get_cached_tree(source)
fields: list[JavaFieldInfo] = []

self._walk_tree_for_fields(tree.root_node, source_bytes, fields, current_class=None, target_class=class_name)
Expand Down Expand Up @@ -678,6 +681,24 @@ def get_package_name(self, source: str) -> str | None:

return None

def _get_cached_tree(self, source: str | bytes) -> tuple[Tree, bytes]:
"""Return a cached parse tree for source or parse and cache it.

This avoids reparsing identical source multiple times while the analyzer
instance lives, which is a major performance win when callers query
classes/fields/methods repeatedly for the same source.
"""
if isinstance(source, str):
source_bytes = source.encode("utf8")
else:
source_bytes = source
tree = self._tree_cache.get(source_bytes)
if tree is None:
tree = self.parse(source_bytes)
# Cache the tree for future calls
self._tree_cache[source_bytes] = tree
return tree, source_bytes


def get_java_analyzer() -> JavaAnalyzer:
"""Get a JavaAnalyzer instance.
Expand Down
92 changes: 47 additions & 45 deletions codeflash/languages/java/replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,61 +144,63 @@ def _insert_class_members(
class_indent = _get_indentation(lines[class_line]) if class_line < len(lines) else ""
member_indent = class_indent + " "

result = source
result_bytes = source_bytes

# Compute original body insertion points once
body_start = body_node.start_byte
body_end = body_node.end_byte

# We'll keep track of a byte-offset delta as we modify the bytes so subsequent
# insertions can use adjusted positions without reparsing.
delta = 0

# Insert fields at the beginning of the class body (after opening brace)

# Insert fields at the beginning of the class body (after opening brace)
if fields:
# Re-parse to get current positions
classes = analyzer.find_classes(result)
for cls in classes:
if cls.name == class_name:
body_node = cls.node.child_by_field_name("body")
break
insert_point = body_start + 1 + delta # After opening brace

# Build field text using list to avoid repeated string concatenation
field_parts: list[str] = []
field_parts.append("\n")
for field in fields:
field_lines = field.strip().splitlines(keepends=True)
indented_field = _apply_indentation(field_lines, member_indent)
field_parts.append(indented_field)
if not indented_field.endswith("\n"):
field_parts.append("\n")
field_text = "".join(field_parts)
field_bytes = field_text.encode("utf8")

if body_node:
result_bytes = result.encode("utf8")
insert_point = body_node.start_byte + 1 # After opening brace
before = result_bytes[:insert_point]
after = result_bytes[insert_point:]
result_bytes = before + field_bytes + after

# Format fields
field_text = "\n"
for field in fields:
field_lines = field.strip().splitlines(keepends=True)
indented_field = _apply_indentation(field_lines, member_indent)
field_text += indented_field
if not indented_field.endswith("\n"):
field_text += "\n"
delta += len(field_bytes) # Adjust for next insertion(s)

before = result_bytes[:insert_point]
after = result_bytes[insert_point:]
result = (before + field_text.encode("utf8") + after).decode("utf8")
# Insert methods at the end of the class body (before closing brace)

# Insert methods at the end of the class body (before closing brace)
if methods:
# Re-parse to get current positions
classes = analyzer.find_classes(result)
for cls in classes:
if cls.name == class_name:
body_node = cls.node.child_by_field_name("body")
break

if body_node:
result_bytes = result.encode("utf8")
insert_point = body_node.end_byte - 1 # Before closing brace
insert_point = body_end - 1 + delta # Before closing brace, adjust by delta

# Format methods
method_text = "\n"
for method in methods:
method_lines = method.strip().splitlines(keepends=True)
indented_method = _apply_indentation(method_lines, member_indent)
method_text += indented_method
if not indented_method.endswith("\n"):
method_text += "\n"

before = result_bytes[:insert_point]
after = result_bytes[insert_point:]
result = (before + method_text.encode("utf8") + after).decode("utf8")

return result
# Build method text efficiently
method_parts: list[str] = []
method_parts.append("\n")
for method in methods:
method_lines = method.strip().splitlines(keepends=True)
indented_method = _apply_indentation(method_lines, member_indent)
method_parts.append(indented_method)
if not indented_method.endswith("\n"):
method_parts.append("\n")
method_text = "".join(method_parts)
method_bytes = method_text.encode("utf8")

before = result_bytes[:insert_point]
after = result_bytes[insert_point:]
result_bytes = before + method_bytes + after

return result_bytes.decode("utf8")


def replace_function(
Expand Down
Loading