diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index d608b253b..521945023 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -218,19 +218,12 @@ def transform(self, source: str) -> str: # Filter out nested assertions (e.g., assertEquals inside assertAll) # An assertion is nested if it's completely contained within another assertion - non_nested: list[AssertionMatch] = [] - for i, assertion in enumerate(assertions_with_targets): - is_nested = False - for j, other in enumerate(assertions_with_targets): - if i != j: - # Check if 'assertion' is nested inside 'other' - if other.start_pos <= assertion.start_pos and assertion.end_pos <= other.end_pos: - is_nested = True - break - if not is_nested: - non_nested.append(assertion) + assertions_with_targets = self._exclude_nested(assertions_with_targets) + + if not assertions_with_targets: + return source - assertions_with_targets = non_nested + # Pre-compute all replacements with correct counter values # Pre-compute all replacements with correct counter values replacements: list[tuple[int, int, str]] = [] @@ -238,12 +231,17 @@ def transform(self, source: str) -> str: replacement = self._generate_replacement(assertion) replacements.append((assertion.start_pos, assertion.end_pos, replacement)) - # Apply replacements in reverse order to preserve positions - result = source - for start_pos, end_pos, replacement in reversed(replacements): - result = result[:start_pos] + replacement + result[end_pos:] + # Apply replacements by building result parts in a single pass to avoid repeated full-string copies + result_parts: list[str] = [] + last_index = 0 + for start_pos, end_pos, replacement in replacements: + # Append unchanged segment and replacement + result_parts.append(source[last_index:start_pos]) + result_parts.append(replacement) + last_index = end_pos + result_parts.append(source[last_index:]) - return result + return "".join(result_parts) def _detect_framework(self, source: str) -> str: """Detect which testing framework is being used from imports. @@ -769,6 +767,89 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str: # Fallback: comment out the assertion return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable" + def _exclude_nested(self, assertions_with_targets: list[AssertionMatch]) -> list[AssertionMatch]: + """Return a list with assertions that are not nested inside any other. + + This preserves the original semantics: an assertion is considered nested if + there exists any other assertion (j != i) such that other.start_pos <= + assertion.start_pos and assertion.end_pos <= other.end_pos. + + The algorithm groups assertions by start_pos and uses a prefix maximum end + to detect containment by earlier-starting assertions. Within a group of the + same start_pos, we use the group's max end and counts of end positions to + detect containment by other assertions that share the same start_pos. + """ + n = len(assertions_with_targets) + if n <= 1: + # If there is 0 or 1 assertion, none can be nested + return assertions_with_targets.copy() + + # Already sorted by start_pos in caller. Create local references for speed. + items = assertions_with_targets + + # Build arrays of start and end for convenience + starts = [a.start_pos for a in items] + ends = [a.end_pos for a in items] + + # prefix_max_end[i] = max end among indices < i + prefix_max_end = [0] * n + max_end = -1 + for i in range(n): + prefix_max_end[i] = max_end + if ends[i] > max_end: + max_end = ends[i] + + # Now scan groups by start_pos and determine for each element if it's nested. + non_nested: list[AssertionMatch] = [] + i = 0 + while i < n: + start_i = starts[i] + # find group range with same start_pos + j = i + 1 + while j < n and starts[j] == start_i: + j += 1 + # group is indices [i, j) + # compute group_max_end and count occurrences of end positions in group + group_max_end = ends[i] + # small optimization: for typical small groups, a linear scan is fine + for k in range(i + 1, j): + if ends[k] > group_max_end: + group_max_end = ends[k] + + # build counts of ends within this group only when needed + # We'll create a dict mapping end -> count only if there are duplicates or equality conditions to check + counts_needed = False + if j - i > 1: + # If group has more than one element, we might need counts if group_max_end equals some ends + counts_needed = True + + end_counts: dict[int, int] | None = None + if counts_needed: + end_counts = {} + for k in range(i, j): + end_counts[ends[k]] = end_counts.get(ends[k], 0) + 1 + + for k in range(i, j): + end_k = ends[k] + # containment by any earlier-starting assertion + if prefix_max_end[k] >= end_k: + # contained by some earlier-starting assertion + continue + # containment by other in same-start group + # exists other with same start and end >= end_k? + if group_max_end > end_k: + continue + if group_max_end == end_k and counts_needed: + # If there is another element in group with the same end value, + # then it contains this one (j != i in original logic). + if (end_counts is not None) and (end_counts.get(end_k, 0) > 1): + continue + # Not nested, keep it + non_nested.append(items[k]) + i = j + + return non_nested + def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str: """Transform Java test code by removing assertions and capturing function calls.