Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 98 additions & 17 deletions codeflash/languages/java/remove_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,32 +218,30 @@ def transform(self, source: str) -> str:

# Filter out nested assertions (e.g., assertEquals inside assertAll)
# An assertion is nested if it's completely contained within another assertion
non_nested: list[AssertionMatch] = []
for i, assertion in enumerate(assertions_with_targets):
is_nested = False
for j, other in enumerate(assertions_with_targets):
if i != j:
# Check if 'assertion' is nested inside 'other'
if other.start_pos <= assertion.start_pos and assertion.end_pos <= other.end_pos:
is_nested = True
break
if not is_nested:
non_nested.append(assertion)
assertions_with_targets = self._exclude_nested(assertions_with_targets)

if not assertions_with_targets:
return source

assertions_with_targets = non_nested
# Pre-compute all replacements with correct counter values

# Pre-compute all replacements with correct counter values
replacements: list[tuple[int, int, str]] = []
for assertion in assertions_with_targets:
replacement = self._generate_replacement(assertion)
replacements.append((assertion.start_pos, assertion.end_pos, replacement))

# Apply replacements in reverse order to preserve positions
result = source
for start_pos, end_pos, replacement in reversed(replacements):
result = result[:start_pos] + replacement + result[end_pos:]
# Apply replacements by building result parts in a single pass to avoid repeated full-string copies
result_parts: list[str] = []
last_index = 0
for start_pos, end_pos, replacement in replacements:
# Append unchanged segment and replacement
result_parts.append(source[last_index:start_pos])
result_parts.append(replacement)
last_index = end_pos
result_parts.append(source[last_index:])

return result
return "".join(result_parts)

def _detect_framework(self, source: str) -> str:
"""Detect which testing framework is being used from imports.
Expand Down Expand Up @@ -769,6 +767,89 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
# Fallback: comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"

def _exclude_nested(self, assertions_with_targets: list[AssertionMatch]) -> list[AssertionMatch]:
"""Return a list with assertions that are not nested inside any other.

This preserves the original semantics: an assertion is considered nested if
there exists any other assertion (j != i) such that other.start_pos <=
assertion.start_pos and assertion.end_pos <= other.end_pos.

The algorithm groups assertions by start_pos and uses a prefix maximum end
to detect containment by earlier-starting assertions. Within a group of the
same start_pos, we use the group's max end and counts of end positions to
detect containment by other assertions that share the same start_pos.
"""
n = len(assertions_with_targets)
if n <= 1:
# If there is 0 or 1 assertion, none can be nested
return assertions_with_targets.copy()

# Already sorted by start_pos in caller. Create local references for speed.
items = assertions_with_targets

# Build arrays of start and end for convenience
starts = [a.start_pos for a in items]
ends = [a.end_pos for a in items]

# prefix_max_end[i] = max end among indices < i
prefix_max_end = [0] * n
max_end = -1
for i in range(n):
prefix_max_end[i] = max_end
if ends[i] > max_end:
max_end = ends[i]

# Now scan groups by start_pos and determine for each element if it's nested.
non_nested: list[AssertionMatch] = []
i = 0
while i < n:
start_i = starts[i]
# find group range with same start_pos
j = i + 1
while j < n and starts[j] == start_i:
j += 1
# group is indices [i, j)
# compute group_max_end and count occurrences of end positions in group
group_max_end = ends[i]
# small optimization: for typical small groups, a linear scan is fine
for k in range(i + 1, j):
if ends[k] > group_max_end:
group_max_end = ends[k]

# build counts of ends within this group only when needed
# We'll create a dict mapping end -> count only if there are duplicates or equality conditions to check
counts_needed = False
if j - i > 1:
# If group has more than one element, we might need counts if group_max_end equals some ends
counts_needed = True

end_counts: dict[int, int] | None = None
if counts_needed:
end_counts = {}
for k in range(i, j):
end_counts[ends[k]] = end_counts.get(ends[k], 0) + 1

for k in range(i, j):
end_k = ends[k]
# containment by any earlier-starting assertion
if prefix_max_end[k] >= end_k:
# contained by some earlier-starting assertion
continue
# containment by other in same-start group
# exists other with same start and end >= end_k?
if group_max_end > end_k:
continue
if group_max_end == end_k and counts_needed:
# If there is another element in group with the same end value,
# then it contains this one (j != i in original logic).
if (end_counts is not None) and (end_counts.get(end_k, 0) > 1):
continue
# Not nested, keep it
non_nested.append(items[k])
i = j

return non_nested


def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
"""Transform Java test code by removing assertions and capturing function calls.
Expand Down
Loading