src/agentlab/llm/chat_api.py (42 changes: 20 additions & 22 deletions)
@@ -8,12 +8,10 @@

import anthropic
import openai
from huggingface_hub import InferenceClient
from openai import NOT_GIVEN, OpenAI

import agentlab.llm.tracking as tracking
from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
from agentlab.llm.huggingface_utils import HFBaseChatModel
from agentlab.llm.llm_utils import AIMessage, Discussion


Expand Down Expand Up @@ -139,6 +137,8 @@ def make_model(self):
self.model_url = os.environ["AGENTLAB_MODEL_URL"]
if self.token is None:
self.token = os.environ["AGENTLAB_MODEL_TOKEN"]
# Lazy import to avoid importing HF utilities on non-HF paths
from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel

return HuggingFaceURLChatModel(
model_name=self.model_name,
@@ -438,28 +438,26 @@ def __init__(
)


-class HuggingFaceURLChatModel(HFBaseChatModel):
-    def __init__(
-        self,
-        model_name: str,
-        base_model_name: str,
-        model_url: str,
-        token: Optional[str] = None,
-        temperature: Optional[int] = 1e-1,
-        max_new_tokens: Optional[int] = 512,
-        n_retry_server: Optional[int] = 4,
-        log_probs: Optional[bool] = False,
-    ):
-        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
-        if temperature < 1e-3:
-            logging.warning("Models might behave weirdly when temperature is too low.")
-        self.temperature = temperature
-
-        if token is None:
-            token = os.environ["TGI_TOKEN"]
-
-        client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
+def __getattr__(name: str):
+    """Lazy re-export of optional classes to keep imports light.
+
+    This lets users import HuggingFaceURLChatModel from agentlab.llm.chat_api
+    without importing heavy dependencies unless actually used.
+
+    Args:
+        name: The name of the attribute to retrieve.
+
+    Returns:
+        The requested class or raises AttributeError if not found.
+
+    Raises:
+        AttributeError: If the requested attribute is not available.
+    """
+    if name == "HuggingFaceURLChatModel":
+        from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel
+
+        return HuggingFaceURLChatModel
+    raise AttributeError(name)


class VLLMChatModel(ChatModel):
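For context, the module-level __getattr__ added above relies on PEP 562 (Python 3.7+): the interpreter invokes it only when an attribute lookup on the module fails, so from agentlab.llm.chat_api import HuggingFaceURLChatModel defers the heavy import until that exact moment. A minimal, runnable sketch of the same pattern, with Decimal standing in for the heavy optional class:

# lazy_mod.py: minimal sketch of the PEP 562 lazy re-export pattern.
# Decimal is a stand-in for a heavy optional class such as HuggingFaceURLChatModel.

def __getattr__(name: str):
    # Python calls this only when `name` is not already a global of this module.
    if name == "Decimal":
        from decimal import Decimal  # deferred: import cost is paid on first access

        return Decimal
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this in place, import lazy_mod stays cheap, while lazy_mod.Decimal (or from lazy_mod import Decimal) re-enters __getattr__ on each lookup; the underlying import itself is cached in sys.modules, so only the first access does real work.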
src/agentlab/llm/huggingface_utils.py (66 changes: 61 additions & 5 deletions)
@@ -1,9 +1,10 @@
import logging
import os
import time
from functools import partial
from typing import Any, List, Optional, Union

from pydantic import Field
from transformers import AutoTokenizer, GPT2TokenizerFast

from agentlab.llm.base_api import AbstractChatModel
from agentlab.llm.llm_utils import AIMessage, Discussion
@@ -45,6 +46,14 @@ def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
self.n_retry_server = n_retry_server
self.log_probs = log_probs

# Lazy import to avoid heavy transformers import when unused
try:
from transformers import AutoTokenizer, GPT2TokenizerFast # type: ignore
except Exception as e: # pragma: no cover - surfaced only when transformers missing
raise ImportError(
"The 'transformers' package is required for HuggingFace models. Install it to use HF backends."
) from e

if base_model_name is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
else:
@@ -60,7 +69,7 @@ def __call__(
self,
messages: list[dict],
n_samples: int = 1,
-        temperature: float = None,
+        temperature: Optional[float] = None,
) -> Union[AIMessage, List[AIMessage]]:
"""
Generate one or more responses for the given messages.
@@ -85,7 +94,7 @@ def __call__(
except Exception as e:
if "Conversation roles must alternate" in str(e):
logging.warning(
f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
"Retrying with the 'system' role appended to the 'user' role."
)
messages = _prepend_system_to_first_user(messages)
@@ -100,7 +109,7 @@ def __call__(
itr = 0
while True:
try:
-                temperature = temperature if temperature is not None else self.temperature
+                temperature = (
+                    temperature
+                    if temperature is not None
+                    else getattr(self, "temperature", 0.1)
+                )
Korbit comment on lines +112 to +116:

Inconsistent temperature default values (category: Readability)

What is the issue?

The temperature default value of 0.1 is hardcoded in multiple places: in the getattr() call and in the class initialization (1e-1).

Why this matters

Having the same magic number in different formats (0.1 vs 1e-1) in multiple places makes it harder to maintain consistent temperature defaults and increases the risk of inconsistency when updating values.

Suggested change:

Define a class-level constant DEFAULT_TEMPERATURE = 0.1 and use it consistently throughout the code:

DEFAULT_TEMPERATURE = 0.1

def __init__(..., temperature: Optional[float] = DEFAULT_TEMPERATURE, ...):
    ...

def __call__(...):
    temperature = temperature if temperature is not None else getattr(self, "temperature", DEFAULT_TEMPERATURE)

answer = self.llm(prompt, temperature=temperature)
response = AIMessage(answer)
if self.log_probs:
@@ -144,9 +157,52 @@ def _prepend_system_to_first_user(messages, column_remap={}):
for msg in messages:
if msg[role_key] == human_key:
# Prepend system content to the first user content
-            msg[text_key] = system_content + "\n" + msg[text_key]
+            msg[text_key] = str(system_content) + "\n" + str(msg[text_key])
# Remove the original system message
del messages[system_index]
break # Ensures that only the first user message is modified

return messages
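As a usage note, the added str() coercion guards against non-string message content (for example, structured content blocks). A small sketch of the helper's observable effect, assuming the default 'role'/'content'/'user' key names implied by column_remap:

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi!"},
]
# After _prepend_system_to_first_user(messages), a template that lacks a
# 'system' role would receive:
# [{"role": "user", "content": "You are terse.\nHi!"}]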


class HuggingFaceURLChatModel(HFBaseChatModel):
"""HF backend using a Text Generation Inference (TGI) HTTP endpoint.

This class is placed here to keep all heavy HF imports optional and only
loaded when a HF backend is explicitly requested.
"""

def __init__(
self,
model_name: str,
model_url: str,
base_model_name: Optional[str] = None,
token: Optional[str] = None,
temperature: Optional[float] = 1e-1,
max_new_tokens: Optional[int] = 512,
n_retry_server: Optional[int] = 4,
log_probs: Optional[bool] = False,
):
super().__init__(model_name, base_model_name, n_retry_server, log_probs)
if temperature is not None and temperature < 1e-3:
logging.warning("Models might behave weirdly when temperature is too low.")
self.temperature = temperature

if token is None:
# support both env var names used elsewhere
token = os.environ.get("TGI_TOKEN") or os.environ.get("AGENTLAB_MODEL_TOKEN")
Korbit comment on lines +189 to +193:

Missing upper bound temperature validation (category: Functionality)

What is the issue?

The temperature parameter is checked against a lower bound (< 1e-3) but never validated against an upper bound.

Why this matters

High temperature values (> 1.0) can lead to extremely random outputs, potentially making the model's responses unusable.

Suggested change:

Add an upper bound check for temperature:

if temperature is not None:
    if temperature < 1e-3:
        logging.warning("Models might behave weirdly when temperature is too low.")
    if temperature > 2.0:
        logging.warning("High temperature values (>2.0) may result in extremely random outputs.")
self.temperature = temperature


# Lazy import huggingface_hub here to avoid import on non-HF paths
try:
from huggingface_hub import InferenceClient # type: ignore
except Exception as e: # pragma: no cover - surfaced only when package missing
raise ImportError(
"The 'huggingface_hub' package is required for HuggingFace URL backends."
) from e

client = InferenceClient(model=model_url, token=token)
self.llm = partial(
client.text_generation,
max_new_tokens=max_new_tokens,
details=log_probs,
)
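End to end, only constructing this class now pays the heavy import cost. A hedged sketch of a caller, assuming huggingface_hub and transformers are installed and a TGI endpoint is reachable; the model id, URL, and message below are placeholders, not values from this PR:

from agentlab.llm.chat_api import HuggingFaceURLChatModel  # resolved lazily via __getattr__

model = HuggingFaceURLChatModel(
    model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder id, used for the tokenizer
    model_url="http://localhost:8080",  # placeholder TGI endpoint
    temperature=0.1,
)
# __call__ applies the chat template, then queries the endpoint with retries.
answer = model([{"role": "user", "content": "Say hello."}])
print(answer)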
src/agentlab/llm/llm_utils.py (21 changes: 18 additions & 3 deletions)
@@ -18,7 +18,6 @@
import tiktoken
import yaml
from PIL import Image
from transformers import AutoModel, AutoTokenizer

langchain_community = importlib.util.find_spec("langchain_community")
if langchain_community is not None:
@@ -512,6 +511,13 @@ def get_tokenizer_old(model_name="openai/gpt-4"):
)
return tiktoken.encoding_for_model("gpt-4")
else:
# Lazy import of transformers only when needed
try:
from transformers import AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(
"The 'transformers' package is required to use non-OpenAI/Azure tokenizers."
) from e
return AutoTokenizer.from_pretrained(model_name)


@@ -522,6 +528,8 @@ def get_tokenizer(model_name="gpt-4"):
except KeyError:
logging.info(f"Could not find a tokenizer for model {model_name}. Trying HuggingFace.")
try:
from transformers import AutoTokenizer # type: ignore

return AutoTokenizer.from_pretrained(model_name)
except Exception as e:
logging.info(f"Could not find a tokenizer for model {model_name}: {e} Defaulting to gpt-4.")
@@ -676,8 +684,8 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
retry_messages = []

for key in all_keys:
-        if not key in content_dict:
-            if not key in optional_keys:
+        if key not in content_dict:
+            if key not in optional_keys:
retry_messages.append(f"Missing the key <{key}> in the answer.")
else:
val = content_dict[key]
Expand All @@ -697,6 +705,13 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):


def download_and_save_model(model_name: str, save_dir: str = "."):
# Lazy import of transformers only when explicitly downloading a model
try:
from transformers import AutoModel # type: ignore
except Exception as e:
raise ImportError(
"The 'transformers' package is required to download and save models."
) from e
model = AutoModel.from_pretrained(model_name)
model.save_pretrained(save_dir)
print(f"Model downloaded and saved to {save_dir}")