diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index f497894e..9ee68838 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -8,12 +8,10 @@
 
 import anthropic
 import openai
-from huggingface_hub import InferenceClient
 from openai import NOT_GIVEN, OpenAI
 
 import agentlab.llm.tracking as tracking
 from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
-from agentlab.llm.huggingface_utils import HFBaseChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
 
 
@@ -139,6 +137,8 @@ def make_model(self):
             self.model_url = os.environ["AGENTLAB_MODEL_URL"]
         if self.token is None:
             self.token = os.environ["AGENTLAB_MODEL_TOKEN"]
+        # Lazy import to avoid importing HF utilities on non-HF paths
+        from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel
 
         return HuggingFaceURLChatModel(
             model_name=self.model_name,
@@ -438,28 +438,26 @@ def __init__(
         )
 
 
-class HuggingFaceURLChatModel(HFBaseChatModel):
-    def __init__(
-        self,
-        model_name: str,
-        base_model_name: str,
-        model_url: str,
-        token: Optional[str] = None,
-        temperature: Optional[int] = 1e-1,
-        max_new_tokens: Optional[int] = 512,
-        n_retry_server: Optional[int] = 4,
-        log_probs: Optional[bool] = False,
-    ):
-        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
-        if temperature < 1e-3:
-            logging.warning("Models might behave weirdly when temperature is too low.")
-        self.temperature = temperature
+def __getattr__(name: str):
+    """Lazy re-export of optional classes to keep imports light.
+
+    This lets users import HuggingFaceURLChatModel from agentlab.llm.chat_api
+    without importing heavy dependencies unless actually used.
+
+    Args:
+        name: The name of the attribute to retrieve.
+
+    Returns:
+        The requested class or raises AttributeError if not found.
 
-        if token is None:
-            token = os.environ["TGI_TOKEN"]
+    Raises:
+        AttributeError: If the requested attribute is not available.
+    """
+    if name == "HuggingFaceURLChatModel":
+        from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel
 
-        client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
+        return HuggingFaceURLChatModel
+    raise AttributeError(name)
 
 
 class VLLMChatModel(ChatModel):
diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 8c3e862d..64a7dab1 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -1,9 +1,10 @@
 import logging
+import os
 import time
+from functools import partial
 from typing import Any, List, Optional, Union
 
 from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast
 
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
@@ -45,6 +46,14 @@ def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         self.n_retry_server = n_retry_server
         self.log_probs = log_probs
 
+        # Lazy import to avoid heavy transformers import when unused
+        try:
+            from transformers import AutoTokenizer, GPT2TokenizerFast  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when transformers missing
+            raise ImportError(
+                "The 'transformers' package is required for HuggingFace models. Install it to use HF backends."
+            ) from e
+
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         else:
@@ -60,7 +69,7 @@ def __call__(
         self,
         messages: list[dict],
         n_samples: int = 1,
-        temperature: float = None,
+        temperature: Optional[float] = None,
     ) -> Union[AIMessage, List[AIMessage]]:
         """
         Generate one or more responses for the given messages.
@@ -85,7 +94,7 @@ def __call__(
         except Exception as e:
             if "Conversation roles must alternate" in str(e):
                 logging.warning(
-                    f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
+                    "Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                     "Retrying with the 'system' role appended to the 'user' role."
                 )
                 messages = _prepend_system_to_first_user(messages)
@@ -100,7 +109,11 @@
         itr = 0
         while True:
             try:
-                temperature = temperature if temperature is not None else self.temperature
+                temperature = (
+                    temperature
+                    if temperature is not None
+                    else getattr(self, "temperature", 0.1)
+                )
                 answer = self.llm(prompt, temperature=temperature)
                 response = AIMessage(answer)
                 if self.log_probs:
@@ -144,9 +157,52 @@ def _prepend_system_to_first_user(messages, column_remap={}):
     for msg in messages:
         if msg[role_key] == human_key:
             # Prepend system content to the first user content
-            msg[text_key] = system_content + "\n" + msg[text_key]
+            msg[text_key] = str(system_content) + "\n" + str(msg[text_key])
             # Remove the original system message
             del messages[system_index]
             break  # Ensures that only the first user message is modified
 
     return messages
+
+
+class HuggingFaceURLChatModel(HFBaseChatModel):
+    """HF backend using a Text Generation Inference (TGI) HTTP endpoint.
+
+    This class is placed here to keep all heavy HF imports optional and only
+    loaded when a HF backend is explicitly requested.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_url: str,
+        base_model_name: Optional[str] = None,
+        token: Optional[str] = None,
+        temperature: Optional[float] = 1e-1,
+        max_new_tokens: Optional[int] = 512,
+        n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
+    ):
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
+        if temperature is not None and temperature < 1e-3:
+            logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
+
+        if token is None:
+            # support both env var names used elsewhere
+            token = os.environ.get("TGI_TOKEN") or os.environ.get("AGENTLAB_MODEL_TOKEN")
+
+        # Lazy import huggingface_hub here to avoid import on non-HF paths
+        try:
+            from huggingface_hub import InferenceClient  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when package missing
+            raise ImportError(
+                "The 'huggingface_hub' package is required for HuggingFace URL backends."
+            ) from e
+
+        client = InferenceClient(model=model_url, token=token)
+        self.llm = partial(
+            client.text_generation,
+            max_new_tokens=max_new_tokens,
+            details=log_probs,
+        )
diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index 10013b72..187fbd4c 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -18,7 +18,6 @@
 import tiktoken
 import yaml
 from PIL import Image
-from transformers import AutoModel, AutoTokenizer
 
 langchain_community = importlib.util.find_spec("langchain_community")
 if langchain_community is not None:
@@ -512,6 +511,13 @@ def get_tokenizer_old(model_name="openai/gpt-4"):
         )
         return tiktoken.encoding_for_model("gpt-4")
     else:
+        # Lazy import of transformers only when needed
+        try:
+            from transformers import AutoTokenizer  # type: ignore
+        except Exception as e:
+            raise ImportError(
+                "The 'transformers' package is required to use non-OpenAI/Azure tokenizers."
+            ) from e
         return AutoTokenizer.from_pretrained(model_name)
 
 
@@ -522,6 +528,8 @@ def get_tokenizer(model_name="gpt-4"):
     except KeyError:
         logging.info(f"Could not find a tokenizer for model {model_name}. Trying HuggingFace.")
         try:
+            from transformers import AutoTokenizer  # type: ignore
+
             return AutoTokenizer.from_pretrained(model_name)
         except Exception as e:
             logging.info(f"Could not find a tokenizer for model {model_name}: {e} Defaulting to gpt-4.")
@@ -676,8 +684,8 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
 
     retry_messages = []
     for key in all_keys:
-        if not key in content_dict:
-            if not key in optional_keys:
+        if key not in content_dict:
+            if key not in optional_keys:
                 retry_messages.append(f"Missing the key <{key}> in the answer.")
         else:
             val = content_dict[key]
@@ -697,6 +705,13 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
 
 
 def download_and_save_model(model_name: str, save_dir: str = "."):
+    # Lazy import of transformers only when explicitly downloading a model
+    try:
+        from transformers import AutoModel  # type: ignore
+    except Exception as e:
+        raise ImportError(
+            "The 'transformers' package is required to download and save models."
+        ) from e
    model = AutoModel.from_pretrained(model_name)
    model.save_pretrained(save_dir)
    print(f"Model downloaded and saved to {save_dir}")
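
Note on the module-level __getattr__ added to chat_api.py: it relies on PEP 562 (Python 3.7+), where Python falls back to a __getattr__ function defined in a module's globals whenever normal attribute lookup on that module fails. That is why `from agentlab.llm.chat_api import HuggingFaceURLChatModel` keeps working while transformers and huggingface_hub stay unimported until the name is actually requested. A minimal, self-contained sketch of the same pattern; the module name, the HeavyClient alias, and the json stand-in are illustrative only and not part of this patch:

# lazy_exports.py -- sketch of the PEP 562 module-level __getattr__ pattern.
# "HeavyClient" and the json stand-in are hypothetical placeholders for a
# heavy optional dependency such as transformers or huggingface_hub.
import sys


def __getattr__(name: str):
    """Resolve optional exports only on first attribute access."""
    if name == "HeavyClient":
        # The expensive import happens here, not at module import time.
        from json import JSONEncoder as HeavyClient  # stand-in for a heavy package

        return HeavyClient
    raise AttributeError(name)


if __name__ == "__main__":
    module = sys.modules[__name__]
    # Normal lookup misses, so Python calls the module-level __getattr__.
    print(getattr(module, "HeavyClient"))  # <class 'json.encoder.JSONEncoder'>

Importing such a module stays cheap; only callers that actually touch the lazy name pay for the heavy import.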
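
The same deferral shows up three times in huggingface_utils.py and llm_utils.py above: the transformers import moves inside the function that needs it and is wrapped so a missing package surfaces as an actionable ImportError (or, in get_tokenizer, degrades to the gpt-4 tokenizer). A hedged sketch of that guard, with an illustrative wrapper name that is not taken from the patch:

# optional_dep.py -- sketch of the deferred-import guard used in the diff above;
# the function name load_hf_tokenizer is illustrative only.
def load_hf_tokenizer(model_name: str):
    try:
        # Deferred import: transformers is only pulled in when an HF tokenizer
        # is actually requested, keeping the importing module light.
        from transformers import AutoTokenizer
    except ImportError as e:
        # Replace a bare ModuleNotFoundError with an actionable message.
        raise ImportError(
            "The 'transformers' package is required for HuggingFace tokenizers."
        ) from e
    return AutoTokenizer.from_pretrained(model_name)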