Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def optimize_python_code_line_profiler(
console.rule()

# Set python_version for backward compatibility with Python, or use language_version
python_version = language_version if language_version else platform.python_version()
python_version = language_version or platform.python_version()

payload = {
"source_code": source_code,
Expand Down Expand Up @@ -868,7 +868,7 @@ def get_optimization_review(
"replay_tests": replay_tests,
"speedup": f"{(100 * float(explanation.speedup)):.2f}%",
"loop_count": explanation.winning_benchmarking_test_results.number_of_loops(),
"benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None,
"benchmark_details": explanation.benchmark_details or None,
"optimized_runtime": humanize_runtime(explanation.best_runtime_ns),
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
"codeflash_version": codeflash_version,
Expand Down
4 changes: 2 additions & 2 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,7 +1436,7 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
module_root = alias.name.split(".")[0]
if module_root in NUMERICAL_MODULES:
# Use the alias if present, otherwise the module name
name = alias.asname if alias.asname else alias.name.split(".")[0]
name = alias.asname or alias.name.split(".")[0]
numerical_names.add(name)
modules_used.add(module_root)
elif isinstance(node, ast.ImportFrom) and node.module:
Expand All @@ -1448,7 +1448,7 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
# Can't track star imports, but mark the module as numerical
numerical_names.add(module_root)
else:
name = alias.asname if alias.asname else alias.name
name = alias.asname or alias.name
numerical_names.add(name)
modules_used.add(module_root)

Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyze
if imp.default_import:
existing_names.add(imp.default_import)
for name, alias in imp.named_imports:
existing_names.add(alias if alias else name)
existing_names.add(alias or name)
if imp.namespace_import:
existing_names.add(imp.namespace_import)

Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/codeflash_wrap_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]:
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]

if test_module and test_function:
return (test_module, test_class if test_class else None, test_function)
return (test_module, test_class or None, test_function)

raise RuntimeError(
"Test context environment variables not set - ensure tests are run through codeflash test runner"
Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/config_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
config["formatter_cmds"] = codeflash_config["formatterCmds"]
else:
detected_formatter = detect_formatter(project_root, package_data)
config["formatter_cmds"] = detected_formatter if detected_formatter else []
config["formatter_cmds"] = detected_formatter or []

# Parse optional config values from codeflash section
if codeflash_config.get("benchmarksRoot"):
Expand Down
10 changes: 5 additions & 5 deletions codeflash/code_utils/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_current_branch(repo: Repo | None = None) -> str:
:return: The name of the current branch, or "main" if HEAD is detached or
the branch cannot be determined.
"""
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
repository: Repo = repo or git.Repo(search_parent_directories=True)

# Check if HEAD is detached (active_branch will be None)
if repository.head.is_detached:
Expand Down Expand Up @@ -106,12 +106,12 @@ def get_current_branch(repo: Repo | None = None) -> str:


def get_remote_url(repo: Repo | None = None, git_remote: str | None = "origin") -> str:
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
repository: Repo = repo or git.Repo(search_parent_directories=True)
return repository.remote(name=git_remote).url


def get_git_remotes(repo: Repo) -> list[str]:
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
repository: Repo = repo or git.Repo(search_parent_directories=True)
return [remote.name for remote in repository.remotes]


Expand All @@ -128,7 +128,7 @@ def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "


def git_root_dir(repo: Repo | None = None) -> Path:
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
repository: Repo = repo or git.Repo(search_parent_directories=True)
return Path(repository.working_dir)


Expand Down Expand Up @@ -199,7 +199,7 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
if "PR_NUMBER" not in os.environ:
return None
try:
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
repository: Repo = repo or git.Repo(search_parent_directories=True)
last_commit = repository.head.commit
except Exception:
logger.exception("Failed to get last commit author.")
Expand Down
Loading
Loading