34 changes: 23 additions & 11 deletions src/agentlab/llm/tracking.py
@@ -163,6 +163,10 @@ def __call__(self, *args, **kwargs):
         response = self._call_api(*args, **kwargs)

         usage = dict(getattr(response, "usage", {}))
+        if "prompt_tokens_details" in usage:
+            usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens
+        if "input_tokens_details" in usage:
+            usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens
         usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
         usage |= {"n_api_calls": 1}
         usage |= {"effective_cost": self.get_effective_cost(response)}
@@ -298,21 +302,29 @@ def get_effective_cost_from_openai_api(self, response) -> float:
         Returns:
             float: The effective cost calculated from the response.
         """
-        usage = getattr(response, "usage", {})
-        prompt_token_details = getattr(response, "prompt_tokens_details", {})
-
-        total_input_tokens = getattr(
-            prompt_token_details, "prompt_tokens", 0
-        )  # Cache read tokens + new input tokens
-        output_tokens = getattr(usage, "completion_tokens", 0)
-        cache_read_tokens = getattr(prompt_token_details, "cached_tokens", 0)
+        usage = getattr(response, "usage", None)
+        if usage is None:
+            logging.warning("No usage information found in the response. Defaulting cost to 0.0.")
+            return 0.0
+        api_type = "chatcompletion" if hasattr(usage, "prompt_tokens_details") else "response"
+        if api_type == "chatcompletion":
+            total_input_tokens = usage.prompt_tokens
+            output_tokens = usage.completion_tokens
+            cached_input_tokens = usage.prompt_tokens_details.cached_tokens
+            non_cached_input_tokens = total_input_tokens - cached_input_tokens
+        elif api_type == "response":
+            total_input_tokens = usage.input_tokens
+            output_tokens = usage.output_tokens
+            cached_input_tokens = usage.input_tokens_details.cached_tokens
+            non_cached_input_tokens = total_input_tokens - cached_input_tokens
+        else:
+            logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.")
+            return 0.0
-
-        non_cached_input_tokens = total_input_tokens - cache_read_tokens
         cache_read_cost = self.input_cost * OPENAI_CACHE_PRICING_FACTOR["cache_read_tokens"]

         effective_cost = (
             self.input_cost * non_cached_input_tokens
-            + cache_read_tokens * cache_read_cost
+            + cached_input_tokens * cache_read_cost
             + self.output_cost * output_tokens
         )
         return effective_cost
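
Since cache-read tokens are billed at a discounted rate, the formula splits the input tokens into non-cached and cached parts before pricing them. A worked example of that arithmetic; the per-token prices and the 0.25 cache-read factor are illustrative assumptions, not constants from this repository:

# Illustrative only: assumed prices and cache discount, not the repo's values.
input_cost = 2.50 / 1_000_000    # $ per fresh input token (assumed)
output_cost = 10.00 / 1_000_000  # $ per output token (assumed)
cache_read_factor = 0.25         # assumed stand-in for OPENAI_CACHE_PRICING_FACTOR["cache_read_tokens"]

total_input_tokens = 10_000
cached_input_tokens = 8_000      # served from the prompt cache
output_tokens = 500
non_cached_input_tokens = total_input_tokens - cached_input_tokens  # 2_000

cache_read_cost = input_cost * cache_read_factor
effective_cost = (
    input_cost * non_cached_input_tokens     # 0.005
    + cached_input_tokens * cache_read_cost  # 0.005
    + output_cost * output_tokens            # 0.005
)
print(f"${effective_cost:.6f}")  # $0.015000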