Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 40 additions & 26 deletions codeflash/languages/java/remove_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,13 @@ def __init__(
self.invocation_counter = 0
self._detected_framework: str | None = None


# Pre-compile regex patterns to avoid recompilation
self._hamcrest_pattern = re.compile(r"(\s*)((?:MatcherAssert\.)?assertThat)\s*\(", re.MULTILINE)
self._method_pattern = re.compile(rf"({re.escape(self.func_name)})\s*\(", re.MULTILINE)
self._new_class_pattern = re.compile(r"new\s+[a-zA-Z_]\w*\s*$")
self._ident_pattern = re.compile(r"[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*\s*$")

def transform(self, source: str) -> str:
"""Remove assertions from source code, preserving target function calls.

Expand Down Expand Up @@ -417,9 +424,7 @@ def _find_hamcrest_assertions(self, source: str) -> list[AssertionMatch]:
return assertions

# Pattern for Hamcrest: assertThat(actual, is(...)) or assertThat(reason, actual, matcher)
pattern = re.compile(r"(\s*)((?:MatcherAssert\.)?assertThat)\s*\(", re.MULTILINE)

for match in pattern.finditer(source):
for match in self._hamcrest_pattern.finditer(source):
leading_ws = match.group(1)
start_pos = match.start()
paren_start = match.end() - 1
Expand Down Expand Up @@ -510,9 +515,7 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
# - method(args) (no receiver)
#
# Strategy: Find the function name, then look backwards for the receiver
pattern = re.compile(rf"({re.escape(self.func_name)})\s*\(", re.MULTILINE)

for match in pattern.finditer(content):
for match in self._method_pattern.finditer(content):
method_name = match.group(1)
method_start = match.start()

Expand All @@ -530,26 +533,27 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
stripped_before = before_method.rstrip()
if stripped_before.endswith("."):
dot_pos = len(stripped_before) - 1
before_dot = content[:dot_pos]
before_dot_content = content[:dot_pos]

# Check for new ClassName() or new ClassName(args)
stripped_before_dot = before_dot.rstrip()
stripped_before_dot = before_dot_content.rstrip()
if stripped_before_dot.endswith(")"):
# Find matching opening paren for constructor args
close_paren_pos = len(stripped_before_dot) - 1
paren_depth = 1
i = close_paren_pos - 1
while i >= 0 and paren_depth > 0:
if stripped_before_dot[i] == ")":
char = stripped_before_dot[i]
if char == ")":
paren_depth += 1
elif stripped_before_dot[i] == "(":
elif char == "(":
paren_depth -= 1
i -= 1
if paren_depth == 0:
open_paren_pos = i + 1
# Look for "new ClassName" before the opening paren
before_paren = stripped_before_dot[:open_paren_pos].rstrip()
new_match = re.search(r"new\s+[a-zA-Z_]\w*\s*$", before_paren)
new_match = self._new_class_pattern.search(before_paren)
if new_match:
receiver_start = new_match.start()
else:
Expand All @@ -558,14 +562,20 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa
receiver_start = open_paren_pos
else:
# Simple identifier: obj.method() or Class.method() or pkg.Class.method()
ident_match = re.search(r"[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*\s*$", stripped_before_dot)
ident_match = self._ident_pattern.search(stripped_before_dot)
if ident_match:
receiver_start = ident_match.start()

full_call = content[receiver_start:end_pos]
receiver = (
content[receiver_start:method_start].rstrip(".").strip() if receiver_start < method_start else None
)

# Extract receiver more efficiently
if receiver_start < method_start:
receiver_text = content[receiver_start:method_start]
# Remove trailing dot and whitespace in one pass
receiver = receiver_text.rstrip().rstrip(".")
else:
receiver = None


target_calls.append(
TargetCall(
Expand Down Expand Up @@ -637,22 +647,26 @@ def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | N
in_string = False
string_char = None
in_char = False
code_len = len(code)

while pos < len(code) and depth > 0:
while pos < code_len and depth > 0:
char = code[pos]
prev_char = code[pos - 1] if pos > 0 else ""

# Handle character literals
if char == "'" and not in_string and prev_char != "\\":
in_char = not in_char
if char == "'" and not in_string:
# Check for escape by looking back
if pos == 0 or code[pos - 1] != "\\":
in_char = not in_char
# Handle string literals (double quotes)
elif char == '"' and not in_char and prev_char != "\\":
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif char == '"' and not in_char:
# Check for escape by looking back
if pos == 0 or code[pos - 1] != "\\":
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif not in_string and not in_char:
if char == "(":
depth += 1
Expand Down
Loading