
⚡️ Speed up function transform_java_assertions by 19% in PR #1199 (omni-java) #1365

Open. codeflash-ai[bot] wants to merge 1 commit into omni-java from codeflash/optimize-pr1199-2026-02-04T02.05.19


@codeflash-ai codeflash-ai bot commented Feb 4, 2026

⚡️ This pull request contains optimizations for PR #1199

If you approve this dependent PR, these changes will be merged into the original PR branch omni-java.

This PR will be automatically closed if the original PR is merged.


📄 19% (0.19x) speedup for transform_java_assertions in codeflash/languages/java/remove_asserts.py

⏱️ Runtime: 27.3 milliseconds → 23.0 milliseconds (best of 91 runs)
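
The percentages in this report are consistent with a speedup measured relative to the optimized runtime, not the original one. The sketch below is only an illustration of that arithmetic; the function name `speedup` is hypothetical and is not part of Codeflash's API.

```python
def speedup(original_runtime: float, optimized_runtime: float) -> float:
    """Relative speedup, assuming it is computed against the optimized runtime.

    Positive values mean the optimized code is faster; negative values mean
    it is slower. Both runtimes must use the same unit.
    """
    return original_runtime / optimized_runtime - 1


# Overall figures quoted in this PR: 27.3 ms -> 23.0 ms
overall = speedup(27.3, 23.0)
print(f"{overall:.1%}")  # about 18.7%, which rounds to the reported 19%
```

Under this definition the per-test annotations also line up: 14.4 ms → 10.8 ms gives roughly 33%, and 2.15 μs → 2.28 μs gives a negative value, which the report labels "slower".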

📝 Explanation and details

This optimization achieves a 19% runtime improvement (27.3 ms → 23.0 ms) by targeting two bottlenecks identified through line profiling:

Primary Optimization: Nested Assertion Filtering (16.2% → 2.9% of runtime)

The original code used a quadratic O(n²) double loop to detect nested assertions, spending 16.2% of total time in the inner loop checking every assertion against every other assertion. The optimized version extracts this logic into _exclude_nested() which:

  1. Groups assertions by start position to avoid redundant comparisons
  2. Uses prefix maximum tracking to detect containment by earlier-starting assertions in O(1) per element
  3. Processes same-start groups efficiently using end position counts only when needed

For test cases with many assertions (like test_many_assertions_with_target_calls with 100 assertions), this reduces ~10,000 comparisons to ~100 operations, achieving 27% faster runtime on that workload.
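
The three steps above can be sketched as follows. This is not the `_exclude_nested()` from the PR, whose source is not shown here. It is a minimal illustration that assumes each assertion is a `(start, end)` pair of non-negative character offsets, and that an assertion counts as nested when any other assertion covers its whole span, including an exact duplicate. The names `exclude_nested_quadratic` and `exclude_nested_sketch` are hypothetical.

```python
from collections import Counter
from itertools import groupby

Span = tuple[int, int]  # (start, end) offsets; an assumed representation


def exclude_nested_quadratic(spans: list[Span]) -> list[Span]:
    """Reference version: compare every span against every other span."""
    kept = []
    for i, (start, end) in enumerate(spans):
        nested = any(
            i != j and o_start <= start and end <= o_end
            for j, (o_start, o_end) in enumerate(spans)
        )
        if not nested:
            kept.append((start, end))
    return kept


def exclude_nested_sketch(spans: list[Span]) -> list[Span]:
    """Same result (in sorted order) using one sort and one linear pass."""
    kept = []
    max_end_before = -1  # largest end among spans with a strictly smaller start
    for _, group_iter in groupby(sorted(spans), key=lambda span: span[0]):
        group = list(group_iter)
        group_max_end = group[-1][1]  # sorted, so the last end is the largest
        end_counts = Counter(end for _, end in group)
        for start, end in group:
            # Covered by a span that starts earlier (prefix maximum check).
            covered_by_earlier = end <= max_end_before
            # Covered by a span with the same start: a longer one or a duplicate.
            covered_by_same_start = end < group_max_end or end_counts[end] > 1
            if not (covered_by_earlier or covered_by_same_start):
                kept.append((start, end))
        max_end_before = max(max_end_before, group_max_end)
    return kept


spans = [(0, 50), (10, 20), (60, 70), (60, 65), (80, 90), (80, 90)]
print(exclude_nested_sketch(spans))
```

For 100 spans the reference version performs 100 × 99 = 9,900 pairwise checks, which matches the "~10,000 comparisons" figure above. The sketch does a constant amount of work per span after sorting, so its total cost is dominated by the O(n log n) sort unless the spans already arrive ordered by start position.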

Secondary Optimization: String Replacement Strategy (0.6% → 0.3% of runtime)

The original code applied replacements in reverse order using repeated string slicing (result[:start] + replacement + result[end:]), creating a new string copy for each replacement. The optimized version:

  1. Builds the result in a single forward pass using list parts
  2. Appends unchanged segments and replacements to avoid intermediate string copies
  3. Joins all parts once at the end with "".join(result_parts)

This is particularly effective for large source files with many assertions (e.g., test_performance_with_large_source), showing 33% faster runtime.
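
The difference between the two strategies can be sketched as below. Both functions are hypothetical illustrations, not the code from `remove_asserts.py`. They assume each replacement is a `(start, end, text)` tuple over the original source and that replacements do not overlap.

```python
Replacement = tuple[int, int, str]  # (start, end, replacement text); assumed shape


def apply_by_slicing(source: str, replacements: list[Replacement]) -> str:
    """Reverse-order slicing: every step copies the whole string again."""
    result = source
    for start, end, text in sorted(replacements, reverse=True):
        result = result[:start] + text + result[end:]
    return result


def apply_in_one_pass(source: str, replacements: list[Replacement]) -> str:
    """Forward pass: collect unchanged segments and replacements, join once."""
    parts = []
    cursor = 0  # position in source up to which text has been emitted
    for start, end, text in sorted(replacements):
        parts.append(source[cursor:start])  # unchanged text before this span
        parts.append(text)
        cursor = end
    parts.append(source[cursor:])  # unchanged tail after the last span
    return "".join(parts)


source = "a(); assertEquals(1, f()); b(); assertEquals(2, f());"
first = "assertEquals(1, f());"
second = "assertEquals(2, f());"
replacements = [
    (source.index(first), source.index(first) + len(first), "f();"),
    (source.index(second), source.index(second) + len(second), "f();"),
]
print(apply_in_one_pass(source, replacements))
```

With k replacements on a source of length n, the slicing version copies on the order of k × n characters in total, while the single pass copies each character once before the final join. Reverse order is what keeps earlier offsets valid in the slicing version; the forward pass needs no such trick because it always indexes into the unmodified source.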

Impact Based on Function References

The function references show transform_java_assertions is called extensively in test transformation workflows, processing assertion-heavy test files. The optimization particularly benefits:

  • Large test suites with many assertions per method (common in parameterized tests)
  • Nested test structures (assertAll blocks) where the nested filtering is critical
  • Test files with 100+ assertions where both optimizations compound their benefits

The changes are purely algorithmic improvements with no behavior modifications—all test cases show identical correctness, just faster execution.

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 46 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests
import pytest
from codeflash.languages.java.remove_asserts import transform_java_assertions

class TestTransformJavaAssertionsBasic:
    """Basic test cases for transform_java_assertions function."""

    def test_empty_source_code(self):
        """Test that empty source code is returned unchanged."""
        codeflash_output = transform_java_assertions("", "testMethod"); result = codeflash_output # 2.15μs -> 2.28μs (5.69% slower)

    def test_whitespace_only_source(self):
        """Test that whitespace-only source is returned unchanged."""
        codeflash_output = transform_java_assertions("   \n  \t  ", "testMethod"); result = codeflash_output # 2.37μs -> 2.29μs (3.44% faster)

    def test_source_without_assertions(self):
        """Test that source without any assertions is returned unchanged."""
        source = """
        public void testMethod() {
            int x = 5;
            String s = "hello";
        }
        """
        codeflash_output = transform_java_assertions(source, "testMethod"); result = codeflash_output # 84.6μs -> 84.0μs (0.669% faster)

    def test_simple_junit5_assert_equals(self):
        """Test transformation of simple JUnit 5 assertEquals assertion."""
        source = """
        import org.junit.jupiter.api.Test;
        import static org.junit.jupiter.api.Assertions.*;
        
        public class MyTest {
            public void testMethod() {
                assertEquals(5, getValue());
            }
        }
        """
        codeflash_output = transform_java_assertions(source, "testMethod"); result = codeflash_output # 178μs -> 179μs (0.234% slower)

    def test_simple_junit4_assert_equals(self):
        """Test transformation of JUnit 4 assertEquals assertion."""
        source = """
        import org.junit.Assert.*;
        
        public class MyTest {
            public void testMethod() {
                assertEquals(5, getValue());
            }
        }
        """
        codeflash_output = transform_java_assertions(source, "testMethod"); result = codeflash_output # 148μs -> 148μs (0.162% slower)

    def test_assertion_with_target_function_call(self):
        """Test that assertions containing target function calls are transformed."""
        source = """
        public void testAdd() {
            assertEquals(3, add(1, 2));
        }
        """
        codeflash_output = transform_java_assertions(source, "add"); result = codeflash_output # 83.1μs -> 81.9μs (1.39% faster)

    def test_assertion_without_target_function_call(self):
        """Test that assertions without target calls are handled."""
        source = """
        public void testSomething() {
            assertEquals(5, 5);
        }
        """
        codeflash_output = transform_java_assertions(source, "targetFunc"); result = codeflash_output # 69.2μs -> 68.3μs (1.34% faster)

    def test_multiple_assertions_with_target_calls(self):
        """Test transformation of multiple assertions with target calls."""
        source = """
        public void testMultiple() {
            assertEquals(1, getValue());
            assertEquals(2, getValue());
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 107μs -> 111μs (3.66% slower)

    def test_assertj_fluent_assertion(self):
        """Test transformation of AssertJ fluent assertion."""
        source = """
        import static org.assertj.core.api.Assertions.*;
        
        public void testMethod() {
            assertThat(getValue()).isEqualTo(5);
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 131μs -> 130μs (1.38% faster)

    def test_custom_qualified_name(self):
        """Test using qualified function name."""
        source = """
        public void test() {
            assertEquals(5, com.example.Math.add(2, 3));
        }
        """
        codeflash_output = transform_java_assertions(source, "add", qualified_name="com.example.Math.add"); result = codeflash_output # 101μs -> 100μs (0.837% faster)

class TestTransformJavaAssertionsEdgeCases:
    """Edge case test cases for transform_java_assertions function."""

    def test_assertion_with_chained_method_calls(self):
        """Test assertion with chained method calls in argument."""
        source = """
        public void test() {
            assertEquals(5, getValue().intValue());
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 83.8μs -> 83.2μs (0.711% faster)

    def test_assertion_with_nested_parentheses(self):
        """Test assertion with nested parentheses in target call."""
        source = """
        public void test() {
            assertEquals(10, add(multiply(2, 3), 4));
        }
        """
        codeflash_output = transform_java_assertions(source, "multiply"); result = codeflash_output # 86.8μs -> 85.9μs (1.05% faster)

    def test_assertion_with_string_literals_containing_quotes(self):
        """Test assertion with string literals that contain quotes."""
        source = '''
        public void test() {
            assertEquals("test\\"value", getValue());
        }
        '''
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 83.9μs -> 83.2μs (0.891% faster)

    def test_assertion_with_multiline_argument(self):
        """Test assertion with arguments spanning multiple lines."""
        source = """
        public void test() {
            assertEquals(
                5,
                getValue()
            );
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 134μs -> 133μs (0.542% faster)

    def test_nested_assertions_outer_removed(self):
        """Test that nested assertions are filtered out."""
        source = """
        public void test() {
            assertAll(
                () -> assertEquals(5, getValue()),
                () -> assertEquals(10, getOther())
            );
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 178μs -> 181μs (1.50% slower)
        # Nested assertions should be handled differently
        # The outer assertAll should not create a capture if it's truly nested

    def test_assertion_with_boolean_value(self):
        """Test assertTrue and assertFalse assertions."""
        source = """
        public void test() {
            assertTrue(isValid());
        }
        """
        codeflash_output = transform_java_assertions(source, "isValid"); result = codeflash_output # 76.1μs -> 75.3μs (1.14% faster)

    def test_assertion_with_null_check(self):
        """Test assertNull and assertNotNull."""
        source = """
        public void test() {
            assertNull(getValue());
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 74.9μs -> 74.5μs (0.497% faster)

    def test_assertion_with_exception_expected(self):
        """Test assertThrows or exception assertion."""
        source = """
        import org.junit.jupiter.api.Test;
        
        public void test() {
            assertThrows(IllegalArgumentException.class, () -> getValue());
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 142μs -> 142μs (0.007% faster)

    def test_assertion_with_lambda_expression(self):
        """Test assertion containing lambda expression."""
        source = """
        public void test() {
            assertAll(
                () -> assertEquals(5, getValue())
            );
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 134μs -> 137μs (2.11% slower)

    def test_assertion_with_method_reference(self):
        """Test assertion with method reference syntax."""
        source = """
        public void test() {
            assertEquals(5, this::getValue);
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 73.1μs -> 72.4μs (1.02% faster)

    def test_source_with_comments(self):
        """Test that comments don't break assertion detection."""
        source = """
        public void test() {
            // This is a comment
            assertEquals(5, getValue()); // inline comment
            /* block comment */ assertEquals(10, getOther());
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 129μs -> 130μs (0.201% slower)

    def test_assertion_in_conditional(self):
        """Test assertion inside an if statement."""
        source = """
        public void test() {
            if (condition) {
                assertEquals(5, getValue());
            }
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 111μs -> 111μs (0.748% faster)

    def test_assertion_in_loop(self):
        """Test assertion inside a loop."""
        source = """
        public void test() {
            for (int i = 0; i < 10; i++) {
                assertEquals(i, getValue(i));
            }
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 126μs -> 124μs (1.13% faster)

    def test_assertion_with_no_arguments(self):
        """Test assertion with no arguments after equals sign."""
        source = """
        public void test() {
            assertEquals(getValue());
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 75.4μs -> 75.0μs (0.481% faster)

    def test_assertion_method_with_different_names(self):
        """Test various assertion method names."""
        for assert_method in ["assertEquals", "assertTrue", "assertFalse", "assertNotNull", "assertNull"]:
            source = f"""
            public void test() {{
                {assert_method}(getValue());
            }}
            """
            codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 404μs -> 401μs (0.826% faster)

    def test_hamcrest_assertion(self):
        """Test Hamcrest matchers assertion."""
        source = """
        import static org.hamcrest.MatcherAssert.assertThat;
        import static org.hamcrest.Matchers.*;
        
        public void test() {
            assertThat(getValue(), is(5));
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 184μs -> 181μs (1.31% faster)

    def test_testng_assertion(self):
        """Test TestNG assertion."""
        source = """
        import org.testng.Assert;
        
        public void test() {
            Assert.assertEquals(getValue(), 5);
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 122μs -> 121μs (0.074% faster)

    def test_assertion_with_complex_expression(self):
        """Test assertion with complex expression in argument."""
        source = """
        public void test() {
            assertEquals(5 + 3 * 2, calculate(getValue()));
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 89.9μs -> 89.5μs (0.413% faster)

    def test_assertion_preserves_leading_whitespace(self):
        """Test that leading whitespace is preserved."""
        source = """
        public void test() {
            if (x > 0) {
                    assertEquals(5, getValue());
            }
        }
        """
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 116μs -> 116μs (0.205% faster)

class TestTransformJavaAssertionsLargeScale:
    """Large scale test cases for transform_java_assertions function."""

    def test_many_assertions_with_target_calls(self):
        """Test handling of many assertions with target function calls."""
        # Create source with 100 assertions containing target calls
        lines = ["public void test() {"]
        for i in range(100):
            lines.append(f"    assertEquals({i}, getValue({i}));")
        lines.append("}")
        source = "\n".join(lines)
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 2.56ms -> 2.01ms (27.0% faster)

    def test_large_method_with_mixed_content(self):
        """Test large method with mix of assertions, code, and comments."""
        lines = ["public void testLarge() {"]
        for i in range(150):
            if i % 3 == 0:
                lines.append(f"    assertEquals({i}, getValue({i}));  // Check {i}")
            elif i % 3 == 1:
                lines.append(f"    int x{i} = processData({i});")
            else:
                lines.append(f"    // Comment line {i}")
        lines.append("}")
        source = "\n".join(lines)
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 2.18ms -> 2.04ms (6.74% faster)

    def test_many_different_target_functions(self):
        """Test assertions calling many different target functions."""
        lines = ["public void test() {"]
        for i in range(100):
            func_name = f"getValue{i}"
            lines.append(f"    assertEquals({i}, {func_name}());")
        lines.append("}")
        source = "\n".join(lines)
        
        # Test with one specific function
        codeflash_output = transform_java_assertions(source, "getValue0"); result = codeflash_output # 1.61ms -> 1.60ms (0.148% faster)

    def test_deeply_nested_method_calls(self):
        """Test assertion with deeply nested method calls."""
        nested_call = "getValue()"
        for i in range(50):
            nested_call = f"transform({nested_call})"
        
        source = f"""
        public void test() {{
            assertEquals(5, {nested_call});
        }}
        """
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 308μs -> 308μs (0.146% slower)

    def test_large_assertion_with_many_arguments(self):
        """Test assertion with many arguments."""
        args = ", ".join([f"arg{i}" for i in range(50)])
        source = f"""
        public void test() {{
            assertEquals(5, complexMethod({args}));
        }}
        """
        
        codeflash_output = transform_java_assertions(source, "complexMethod"); result = codeflash_output # 299μs -> 298μs (0.201% faster)

    def test_multiple_classes_with_assertions(self):
        """Test source with multiple class definitions."""
        source = """
        public class TestClass1 {
            public void test1() {
                assertEquals(5, getValue());
            }
            public void test2() {
                assertEquals(10, getValue());
            }
        }
        
        public class TestClass2 {
            public void test3() {
                assertEquals(15, getValue());
            }
        }
        """
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 276μs -> 277μs (0.538% slower)

    def test_assertions_with_string_array_arguments(self):
        """Test assertion with array literal arguments."""
        source = """
        public void test() {
            assertEquals(new int[]{1, 2, 3}, getArray());
        }
        """
        
        codeflash_output = transform_java_assertions(source, "getArray"); result = codeflash_output # 92.2μs -> 91.7μs (0.601% faster)

    def test_assertion_counter_increments_correctly(self):
        """Test that internal counter increments and creates unique variable names."""
        source = """
        public void test() {
            assertEquals(1, getValue());
            assertEquals(2, getValue());
            assertEquals(3, getValue());
        }
        """
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 129μs -> 132μs (2.25% slower)

    def test_performance_with_large_source(self):
        """Test performance with large source code file."""
        lines = ["public class LargeTest {"]
        
        # Add 500 lines of various code
        for i in range(500):
            if i % 10 == 0:
                lines.append(f"    public void test{i}() {{")
                for j in range(5):
                    lines.append(f"        assertEquals({i * j}, getValue({i}, {j}));")
                lines.append("    }")
            else:
                lines.append(f"    int field{i} = {i};")
        
        lines.append("}")
        source = "\n".join(lines)
        
        # This should complete without timeout
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 14.4ms -> 10.8ms (33.4% faster)

    def test_assertion_positions_preserved_with_many_lines(self):
        """Test that replacement positions are correct with many preceding lines."""
        lines = ["public void test() {"]
        # Add 200 lines of filler code
        for i in range(200):
            lines.append(f"    int x{i} = {i};")
        # Add assertions
        for i in range(10):
            lines.append(f"    assertEquals({i}, getValue({i}));")
        lines.append("}")
        source = "\n".join(lines)
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 1.75ms -> 1.74ms (0.558% faster)

    def test_empty_lines_and_formatting_preserved(self):
        """Test that empty lines and formatting are mostly preserved."""
        source = """
        public void test() {
        
            assertEquals(5, getValue());
            
            
            assertEquals(10, getValue());
            
        }
        """
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 181μs -> 184μs (1.62% slower)

    def test_boundary_assertion_at_start(self):
        """Test assertion as first statement in method."""
        source = """
        public void test() {
            assertEquals(5, getValue());
        }
        """
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 80.0μs -> 79.3μs (0.896% faster)

    def test_boundary_assertion_at_end(self):
        """Test assertion as last statement in method."""
        source = """
        public void test() {
            int x = 5;
            String s = "test";
            assertEquals(5, getValue());
        }
        """
        
        codeflash_output = transform_java_assertions(source, "getValue"); result = codeflash_output # 116μs -> 118μs (1.16% slower)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-pr1199-2026-02-04T02.05.19`, commit your edits, and push.

@codeflash-ai codeflash-ai bot added the labels "⚡️ codeflash" (Optimization PR opened by Codeflash AI) and "🎯 Quality: High" (Optimization Quality according to Codeflash) on Feb 4, 2026
@codeflash-ai codeflash-ai bot mentioned this pull request Feb 4, 2026