diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index d608b253b..f530bf067 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -185,6 +185,13 @@ def __init__( self.invocation_counter = 0 self._detected_framework: str | None = None + + # Pre-compile regex patterns to avoid recompilation + self._hamcrest_pattern = re.compile(r"(\s*)((?:MatcherAssert\.)?assertThat)\s*\(", re.MULTILINE) + self._method_pattern = re.compile(rf"({re.escape(self.func_name)})\s*\(", re.MULTILINE) + self._new_class_pattern = re.compile(r"new\s+[a-zA-Z_]\w*\s*$") + self._ident_pattern = re.compile(r"[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*\s*$") + def transform(self, source: str) -> str: """Remove assertions from source code, preserving target function calls. @@ -417,9 +424,7 @@ def _find_hamcrest_assertions(self, source: str) -> list[AssertionMatch]: return assertions # Pattern for Hamcrest: assertThat(actual, is(...)) or assertThat(reason, actual, matcher) - pattern = re.compile(r"(\s*)((?:MatcherAssert\.)?assertThat)\s*\(", re.MULTILINE) - - for match in pattern.finditer(source): + for match in self._hamcrest_pattern.finditer(source): leading_ws = match.group(1) start_pos = match.start() paren_start = match.end() - 1 @@ -510,9 +515,7 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa # - method(args) (no receiver) # # Strategy: Find the function name, then look backwards for the receiver - pattern = re.compile(rf"({re.escape(self.func_name)})\s*\(", re.MULTILINE) - - for match in pattern.finditer(content): + for match in self._method_pattern.finditer(content): method_name = match.group(1) method_start = match.start() @@ -530,26 +533,27 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa stripped_before = before_method.rstrip() if stripped_before.endswith("."): dot_pos = len(stripped_before) - 1 - before_dot = content[:dot_pos] + before_dot_content = content[:dot_pos] # Check for new ClassName() or new ClassName(args) - stripped_before_dot = before_dot.rstrip() + stripped_before_dot = before_dot_content.rstrip() if stripped_before_dot.endswith(")"): # Find matching opening paren for constructor args close_paren_pos = len(stripped_before_dot) - 1 paren_depth = 1 i = close_paren_pos - 1 while i >= 0 and paren_depth > 0: - if stripped_before_dot[i] == ")": + char = stripped_before_dot[i] + if char == ")": paren_depth += 1 - elif stripped_before_dot[i] == "(": + elif char == "(": paren_depth -= 1 i -= 1 if paren_depth == 0: open_paren_pos = i + 1 # Look for "new ClassName" before the opening paren before_paren = stripped_before_dot[:open_paren_pos].rstrip() - new_match = re.search(r"new\s+[a-zA-Z_]\w*\s*$", before_paren) + new_match = self._new_class_pattern.search(before_paren) if new_match: receiver_start = new_match.start() else: @@ -558,14 +562,20 @@ def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCa receiver_start = open_paren_pos else: # Simple identifier: obj.method() or Class.method() or pkg.Class.method() - ident_match = re.search(r"[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*\s*$", stripped_before_dot) + ident_match = self._ident_pattern.search(stripped_before_dot) if ident_match: receiver_start = ident_match.start() full_call = content[receiver_start:end_pos] - receiver = ( - content[receiver_start:method_start].rstrip(".").strip() if receiver_start < method_start else None - ) + + # Extract receiver more efficiently + if receiver_start < method_start: + receiver_text = content[receiver_start:method_start] + # Remove trailing dot and whitespace in one pass + receiver = receiver_text.rstrip().rstrip(".") + else: + receiver = None + target_calls.append( TargetCall( @@ -637,22 +647,26 @@ def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | N in_string = False string_char = None in_char = False + code_len = len(code) - while pos < len(code) and depth > 0: + while pos < code_len and depth > 0: char = code[pos] - prev_char = code[pos - 1] if pos > 0 else "" # Handle character literals - if char == "'" and not in_string and prev_char != "\\": - in_char = not in_char + if char == "'" and not in_string: + # Check for escape by looking back + if pos == 0 or code[pos - 1] != "\\": + in_char = not in_char # Handle string literals (double quotes) - elif char == '"' and not in_char and prev_char != "\\": - if not in_string: - in_string = True - string_char = char - elif char == string_char: - in_string = False - string_char = None + elif char == '"' and not in_char: + # Check for escape by looking back + if pos == 0 or code[pos - 1] != "\\": + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None elif not in_string and not in_char: if char == "(": depth += 1