Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 241 additions & 8 deletions codeflash/languages/java/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Optional, TYPE_CHECKING

from tree_sitter import Language, Parser
import re

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -102,6 +103,47 @@ class JavaFieldInfo:
source_text: str




class _BodyNodeLike:
"""Lightweight stand-in for a tree-sitter body node; only provides
the .start_byte and .end_byte attributes which callers in this module use.
"""

__slots__ = ("start_byte", "end_byte")

def __init__(self, start_byte: int, end_byte: int) -> None:
self.start_byte = start_byte
self.end_byte = end_byte


class NodeLike:
    """Minimal substitute for a tree-sitter declaration node.

    Holds the node's own ``start_byte`` / ``end_byte`` and a body stand-in
    built from ``body_start`` / ``body_end``. The only lookup supported is
    ``child_by_field_name("body")``.
    """

    __slots__ = ("_body", "start_byte", "end_byte")

    def __init__(self, body_start: int, body_end: int, start_byte: int, end_byte: int) -> None:
        self.start_byte = start_byte
        self.end_byte = end_byte
        # The body is wrapped in its own object so it exposes .start_byte/.end_byte too.
        self._body = _BodyNodeLike(body_start, body_end)

    def child_by_field_name(self, field: str):
        """Return the body stand-in for ``"body"``; every other field name yields None."""
        return self._body if field == "body" else None


class _UnpicklableMarker:
"""A lightweight object that cannot be pickled, used to maintain
behavioral compatibility with the original tree_sitter.Parser-based
implementation which also could not be pickled.
"""
def __reduce__(self):
raise TypeError("cannot pickle '_UnpicklableMarker' object")


class JavaAnalyzer:
"""Java code analysis using tree-sitter.

Expand All @@ -111,7 +153,11 @@ class JavaAnalyzer:

def __init__(self) -> None:
    """Initialize the Java analyzer.

    No tree-sitter ``Parser`` is constructed here; ``_parser`` is set to an
    unpicklable placeholder object instead.
    """
    # The placeholder refuses pickling, so an analyzer instance stays
    # unpicklable just as it was when this attribute held a real Parser.
    # NOTE(review): the ``parser`` property body is not visible from here.
    # If it lazily builds a Parser only when ``_parser is None``, it will
    # now hand back this marker instead of a Parser — confirm before merging.
    self._parser: Parser | _UnpicklableMarker | None = _UnpicklableMarker()

@property
def parser(self) -> Parser:
Expand Down Expand Up @@ -318,13 +364,200 @@ def find_classes(self, source: str) -> list[JavaClassNode]:
List of JavaClassNode objects.

"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
classes: list[JavaClassNode] = []

self._walk_tree_for_classes(tree.root_node, source_bytes, classes, is_inner=False)
if not source:
return []

src = source # local alias for speed
src_len = len(src)

# Precompute line start indices for quick line/column calculation
# line_starts[i] is the character index where line i (0-based) begins
line_starts = [0]
for i, ch in enumerate(src):
if ch == "\n":
# next line starts after this newline
line_starts.append(i + 1)

def char_pos_to_line_col(pos: int) -> tuple[int, int]:
    """Translate a character offset into a 1-based ``(line, column)`` pair.

    Relies on the enclosing ``line_starts`` list, whose entry ``i`` is the
    character index at which 0-based line ``i`` begins (sorted ascending,
    first entry 0).

    Args:
        pos: Character offset into the source text.

    Returns:
        ``(line, column)``, both counted from 1.
    """
    # Function-scope import keeps this helper self-contained.
    from bisect import bisect_right

    # bisect_right gives the number of line starts <= pos; the last of those
    # is the line containing pos.
    line = bisect_right(line_starts, pos) - 1
    return line + 1, pos - line_starts[line] + 1

# Helper: convert character index to UTF-8 byte offset (one-time encode and slice)
src_bytes = src.encode("utf8")

# To convert a character offset to byte offset, we compute the byte length
# of the prefix up to that character. For speed, we memoize some boundaries.
# Create a mapping of every Nth character boundary to its byte offset to avoid
# repeated expensive encodings for long files. Choose N based on file size.
N = 1024
checkpoints: dict[int, int] = {0: 0}
if src_len > N:
# create checkpoints every N characters
for i in range(N, src_len, N):
checkpoints[i] = len(src[:i].encode("utf8"))

def char_to_byte_index(char_index: int) -> int:
    """Convert a character offset in ``src`` to its UTF-8 byte offset.

    Uses the enclosing ``checkpoints`` mapping (character index -> byte
    offset, recorded at 0 and at multiples of ``N``) so that only the text
    between the nearest earlier checkpoint and ``char_index`` is encoded.

    Args:
        char_index: Character offset into ``src``.

    Returns:
        Number of UTF-8 bytes occupied by ``src[:char_index]``.
    """
    # Checkpoints sit on multiples of N, so jump straight to the candidate
    # instead of scanning every key. Clamp at 0 so a negative index falls
    # back to the always-present checkpoint 0.
    anchor = (max(char_index, 0) // N) * N
    # The candidate can be missing: checkpoints are built with range(), which
    # excludes src_len itself. Step back until a recorded one is found; the
    # walk ends at 0 at the latest.
    while anchor not in checkpoints:
        anchor -= N
    return checkpoints[anchor] + len(src[anchor:char_index].encode("utf8"))

# Regex to find class/interface/enum declarations and their names.
decl_re = re.compile(r"\b(class|interface|enum)\s+([A-Za-z_]\w*)", flags=re.MULTILINE)

results: list[JavaClassNode] = []

# State machine to find matching brace index while skipping strings and comments.
def find_matching_brace(start_idx: int) -> Optional[int]:
    """Locate the ``}`` that balances the ``{`` at ``start_idx`` in ``src``.

    Scans forward through the enclosing ``src`` (length ``src_len``),
    counting brace depth. Braces inside ``//`` comments, ``/* */`` comments,
    char literals, string literals and triple-quoted text blocks are ignored.

    Args:
        start_idx: Character index of the opening ``{``.

    Returns:
        Character index of the matching ``}``, or ``None`` when the text ends
        before the braces balance (including inside an unterminated comment).
    """
    text = src
    limit = src_len
    pos = start_idx
    depth = 0
    while pos < limit:
        ch = text[pos]
        if ch == "/" and pos + 1 < limit:
            follower = text[pos + 1]
            if follower == "/":
                # Line comment: resume at the newline that ends it.
                newline = text.find("\n", pos + 2)
                if newline == -1:
                    return None  # comment runs to end of text; nothing left to match
                pos = newline
                continue
            if follower == "*":
                # Block comment: resume just past the closing "*/".
                close = text.find("*/", pos + 2)
                if close == -1:
                    return None  # unterminated comment swallows the rest of the text
                pos = close + 2
                continue
            # A lone "/" (e.g. division) falls through to the increment below.
        elif ch in ('"', "'"):
            # A text block opens and closes with three double quotes. Treating
            # it as ordinary strings would end it at the first inner quote
            # and expose any braces that follow.
            closer = '"""' if text.startswith('"""', pos) else ch
            pos += len(closer)
            while pos < limit:
                if text[pos] == "\\":
                    pos += 2  # skip the backslash and the character it escapes
                elif text.startswith(closer, pos):
                    pos += len(closer)
                    break
                else:
                    pos += 1
            continue
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return pos
        pos += 1
    return None

# Iterate over declaration matches
for m in decl_re.finditer(src):
decl_start = m.start()
kind = m.group(1)
name = m.group(2)

# Find the opening brace for this declaration
brace_idx = src.find("{", m.end())
if brace_idx == -1:
# No body found; skip
continue

# Find matching closing brace robustly
end_brace_idx = find_matching_brace(brace_idx)
if end_brace_idx is None:
# Unbalanced braces; skip
continue

# Compute byte offsets in UTF-8 encoding consistent with tree-sitter
body_start_byte = char_to_byte_index(brace_idx)
body_end_byte = char_to_byte_index(end_brace_idx + 1) # one past closing brace

node_start_byte = char_to_byte_index(decl_start)
node_end_byte = body_end_byte # node end at body end

node = NodeLike(body_start_byte, body_end_byte, node_start_byte, node_end_byte)

# Extract header text between declaration start and opening brace
header_text = src[decl_start:brace_idx]

# Determine modifiers
is_public = bool(re.search(r"\bpublic\b", header_text))
is_abstract = bool(re.search(r"\babstract\b", header_text))
is_final = bool(re.search(r"\bfinal\b", header_text))
is_static = bool(re.search(r"\bstatic\b", header_text))

# Extract extends
extends_match = re.search(r"\bextends\s+([A-Za-z0-9_\.<>]+)", header_text)
extends = extends_match.group(1).strip() if extends_match else None

# Extract implements (comma-separated)
implements: list[str] = []
impl_match = re.search(r"\bimplements\s+([^<{]*?)\s*$", header_text)
if impl_match:
impl_text = impl_match.group(1)
# split on commas and strip whitespace
implements = [p.strip() for p in impl_text.split(",") if p.strip()]

# Source text for class (from declaration start to closing brace inclusive)
source_text = src[decl_start : end_brace_idx + 1]

# Try to detect Javadoc start line: look for last '/**' before declaration
javadoc_start_line: int | None = None
javadoc_pos = src.rfind("/**", 0, decl_start)
if javadoc_pos != -1:
# Ensure there is a closing '*/' between javadoc_pos and decl_start
close_pos = src.find("*/", javadoc_pos, decl_start)
if close_pos != -1:
# This looks like a javadoc block before the declaration
javadoc_start_line = char_pos_to_line_col(javadoc_pos)[0]

start_line, start_col = char_pos_to_line_col(decl_start)
end_line, end_col = char_pos_to_line_col(end_brace_idx)

jcnode = JavaClassNode(
name=name,
node=node,
start_line=start_line,
end_line=end_line,
start_col=start_col,
end_col=end_col,
is_public=is_public,
is_abstract=is_abstract,
is_final=is_final,
is_static=is_static,
extends=extends,
implements=implements,
source_text=source_text,
javadoc_start_line=javadoc_start_line,
)
results.append(jcnode)

return classes
return results

def _walk_tree_for_classes(
self, node: Node, source_bytes: bytes, classes: list[JavaClassNode], is_inner: bool
Expand Down
Loading