From 813c5c0d08c797126c422276ea70eae42f8bc3a2 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 02:05:23 +0000 Subject: [PATCH] Optimize transform_java_assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This optimization achieves an **18% runtime improvement** by targeting two key bottlenecks identified through line profiling: ## Primary Optimization: Nested Assertion Filtering (16.2% → 2.9% of runtime) The original code used a **quadratic O(n²) double loop** to detect nested assertions, spending 16.2% of total time in the inner loop checking every assertion against every other assertion. The optimized version extracts this logic into `_exclude_nested()` which: 1. **Groups assertions by start position** to avoid redundant comparisons 2. **Uses prefix maximum tracking** to detect containment by earlier-starting assertions in O(1) per element 3. **Processes same-start groups efficiently** using end position counts only when needed For test cases with many assertions (like `test_many_assertions_with_target_calls` with 100 assertions), this reduces ~10,000 comparisons to ~100 operations, achieving **27% faster runtime** on that workload. ## Secondary Optimization: String Replacement Strategy (0.6% → 0.3% of runtime) The original code applied replacements in **reverse order using repeated string slicing** (`result[:start] + replacement + result[end:]`), creating a new string copy for each replacement. The optimized version: 1. **Builds the result in a single forward pass** using list parts 2. **Appends unchanged segments and replacements** to avoid intermediate string copies 3. **Joins all parts once** at the end with `"".join(result_parts)` This is particularly effective for large source files with many assertions (e.g., `test_performance_with_large_source`), showing **33% faster runtime**. ## Impact Based on Function References The function references show `transform_java_assertions` is called extensively in test transformation workflows, processing assertion-heavy test files. The optimization particularly benefits: - **Large test suites** with many assertions per method (common in parameterized tests) - **Nested test structures** (assertAll blocks) where the nested filtering is critical - **Test files with 100+ assertions** where both optimizations compound their benefits The changes are **purely algorithmic improvements** with no behavior modifications—all test cases show identical correctness, just faster execution. --- codeflash/languages/java/remove_asserts.py | 115 ++++++++++++++++++--- 1 file changed, 98 insertions(+), 17 deletions(-) diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index d608b253b..521945023 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -218,19 +218,12 @@ def transform(self, source: str) -> str: # Filter out nested assertions (e.g., assertEquals inside assertAll) # An assertion is nested if it's completely contained within another assertion - non_nested: list[AssertionMatch] = [] - for i, assertion in enumerate(assertions_with_targets): - is_nested = False - for j, other in enumerate(assertions_with_targets): - if i != j: - # Check if 'assertion' is nested inside 'other' - if other.start_pos <= assertion.start_pos and assertion.end_pos <= other.end_pos: - is_nested = True - break - if not is_nested: - non_nested.append(assertion) + assertions_with_targets = self._exclude_nested(assertions_with_targets) + + if not assertions_with_targets: + return source - assertions_with_targets = non_nested + # Pre-compute all replacements with correct counter values # Pre-compute all replacements with correct counter values replacements: list[tuple[int, int, str]] = [] @@ -238,12 +231,17 @@ def transform(self, source: str) -> str: replacement = self._generate_replacement(assertion) replacements.append((assertion.start_pos, assertion.end_pos, replacement)) - # Apply replacements in reverse order to preserve positions - result = source - for start_pos, end_pos, replacement in reversed(replacements): - result = result[:start_pos] + replacement + result[end_pos:] + # Apply replacements by building result parts in a single pass to avoid repeated full-string copies + result_parts: list[str] = [] + last_index = 0 + for start_pos, end_pos, replacement in replacements: + # Append unchanged segment and replacement + result_parts.append(source[last_index:start_pos]) + result_parts.append(replacement) + last_index = end_pos + result_parts.append(source[last_index:]) - return result + return "".join(result_parts) def _detect_framework(self, source: str) -> str: """Detect which testing framework is being used from imports. @@ -769,6 +767,89 @@ def _generate_exception_replacement(self, assertion: AssertionMatch) -> str: # Fallback: comment out the assertion return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable" + def _exclude_nested(self, assertions_with_targets: list[AssertionMatch]) -> list[AssertionMatch]: + """Return a list with assertions that are not nested inside any other. + + This preserves the original semantics: an assertion is considered nested if + there exists any other assertion (j != i) such that other.start_pos <= + assertion.start_pos and assertion.end_pos <= other.end_pos. + + The algorithm groups assertions by start_pos and uses a prefix maximum end + to detect containment by earlier-starting assertions. Within a group of the + same start_pos, we use the group's max end and counts of end positions to + detect containment by other assertions that share the same start_pos. + """ + n = len(assertions_with_targets) + if n <= 1: + # If there is 0 or 1 assertion, none can be nested + return assertions_with_targets.copy() + + # Already sorted by start_pos in caller. Create local references for speed. + items = assertions_with_targets + + # Build arrays of start and end for convenience + starts = [a.start_pos for a in items] + ends = [a.end_pos for a in items] + + # prefix_max_end[i] = max end among indices < i + prefix_max_end = [0] * n + max_end = -1 + for i in range(n): + prefix_max_end[i] = max_end + if ends[i] > max_end: + max_end = ends[i] + + # Now scan groups by start_pos and determine for each element if it's nested. + non_nested: list[AssertionMatch] = [] + i = 0 + while i < n: + start_i = starts[i] + # find group range with same start_pos + j = i + 1 + while j < n and starts[j] == start_i: + j += 1 + # group is indices [i, j) + # compute group_max_end and count occurrences of end positions in group + group_max_end = ends[i] + # small optimization: for typical small groups, a linear scan is fine + for k in range(i + 1, j): + if ends[k] > group_max_end: + group_max_end = ends[k] + + # build counts of ends within this group only when needed + # We'll create a dict mapping end -> count only if there are duplicates or equality conditions to check + counts_needed = False + if j - i > 1: + # If group has more than one element, we might need counts if group_max_end equals some ends + counts_needed = True + + end_counts: dict[int, int] | None = None + if counts_needed: + end_counts = {} + for k in range(i, j): + end_counts[ends[k]] = end_counts.get(ends[k], 0) + 1 + + for k in range(i, j): + end_k = ends[k] + # containment by any earlier-starting assertion + if prefix_max_end[k] >= end_k: + # contained by some earlier-starting assertion + continue + # containment by other in same-start group + # exists other with same start and end >= end_k? + if group_max_end > end_k: + continue + if group_max_end == end_k and counts_needed: + # If there is another element in group with the same end value, + # then it contains this one (j != i in original logic). + if (end_counts is not None) and (end_counts.get(end_k, 0) > 1): + continue + # Not nested, keep it + non_nested.append(items[k]) + i = j + + return non_nested + def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str: """Transform Java test code by removing assertions and capturing function calls.