Make transformer lazy import #292
Changes from all commits
```diff
@@ -1,9 +1,10 @@
 import logging
 import os
 import time
 from functools import partial
 from typing import Any, List, Optional, Union

 from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast

 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
```
```diff
@@ -45,6 +46,14 @@ def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         self.n_retry_server = n_retry_server
         self.log_probs = log_probs

+        # Lazy import to avoid heavy transformers import when unused
+        try:
+            from transformers import AutoTokenizer, GPT2TokenizerFast  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when transformers missing
+            raise ImportError(
+                "The 'transformers' package is required for HuggingFace models. Install it to use HF backends."
+            ) from e
+
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         else:
```
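Since the point of the PR is that importing this module no longer drags in transformers at module load, a quick sanity check is to inspect sys.modules right after import. A minimal sketch, assuming the file lives at agentlab.llm.huggingface_utils (the module path is an assumption, not shown in this diff):

```python
import sys

# Assumed module path for the file being patched; not stated in the diff.
import agentlab.llm.huggingface_utils  # noqa: F401

# With the lazy import in place, transformers should not be loaded yet
# (assuming nothing else in the process imported it first).
assert "transformers" not in sys.modules, "transformers was imported eagerly"
```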
```diff
@@ -60,7 +69,7 @@ def __call__(
         self,
         messages: list[dict],
         n_samples: int = 1,
-        temperature: float = None,
+        temperature: Optional[float] = None,
     ) -> Union[AIMessage, List[AIMessage]]:
         """
         Generate one or more responses for the given messages.
```
```diff
@@ -85,7 +94,7 @@ def __call__(
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
                     logging.warning(
-                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
+                        "Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                         "Retrying with the 'system' role appended to the 'user' role."
                     )
                     messages = _prepend_system_to_first_user(messages)
```
```diff
@@ -100,7 +109,11 @@
         itr = 0
         while True:
             try:
-                temperature = temperature if temperature is not None else self.temperature
+                temperature = (
+                    temperature
+                    if temperature is not None
+                    else getattr(self, "temperature", 0.1)
+                )
                 answer = self.llm(prompt, temperature=temperature)
                 response = AIMessage(answer)
                 if self.log_probs:
```
```diff
@@ -144,9 +157,52 @@ def _prepend_system_to_first_user(messages, column_remap={}):
     for msg in messages:
         if msg[role_key] == human_key:
             # Prepend system content to the first user content
-            msg[text_key] = system_content + "\n" + msg[text_key]
+            msg[text_key] = str(system_content) + "\n" + str(msg[text_key])
             # Remove the original system message
             del messages[system_index]
             break  # Ensures that only the first user message is modified

     return messages
+
+
+class HuggingFaceURLChatModel(HFBaseChatModel):
+    """HF backend using a Text Generation Inference (TGI) HTTP endpoint.
+
+    This class is placed here to keep all heavy HF imports optional and only
+    loaded when a HF backend is explicitly requested.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_url: str,
+        base_model_name: Optional[str] = None,
+        token: Optional[str] = None,
+        temperature: Optional[float] = 1e-1,
+        max_new_tokens: Optional[int] = 512,
+        n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
+    ):
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
+        if temperature is not None and temperature < 1e-3:
+            logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
+
+        if token is None:
+            # support both env var names used elsewhere
+            token = os.environ.get("TGI_TOKEN") or os.environ.get("AGENTLAB_MODEL_TOKEN")
```
Korbit comment on lines +189 to +193:

Missing Upper Bound Temperature Validation

What is the issue?
The temperature check only validates the lower bound (< 1e-3); there is no upper-bound validation.

Why this matters
High temperature values (> 1.0) can lead to extremely random outputs, potentially making the model's responses unusable.

Suggested change
Add an upper bound check for temperature:

```python
if temperature is not None:
    if temperature < 1e-3:
        logging.warning("Models might behave weirdly when temperature is too low.")
    if temperature > 2.0:
        logging.warning("High temperature values (>2.0) may result in extremely random outputs.")
self.temperature = temperature
```
```diff
+        # Lazy import huggingface_hub here to avoid import on non-HF paths
+        try:
+            from huggingface_hub import InferenceClient  # type: ignore
+        except Exception as e:  # pragma: no cover - surfaced only when package missing
+            raise ImportError(
+                "The 'huggingface_hub' package is required for HuggingFace URL backends."
+            ) from e
+
+        client = InferenceClient(model=model_url, token=token)
+        self.llm = partial(
+            client.text_generation,
+            max_new_tokens=max_new_tokens,
+            details=log_probs,
+        )
```
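For context, here is a hypothetical usage of the new class against a local TGI server; the model name, endpoint URL, and message keys are illustrative placeholders, not values from this PR:

```python
# Hypothetical example; requires transformers and huggingface_hub installed
# and a TGI server already running at the given URL.
model = HuggingFaceURLChatModel(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id
    model_url="http://localhost:8080",  # placeholder TGI endpoint
    temperature=0.1,
)
messages = [{"role": "user", "content": "Say hello."}]
response = model(messages)  # AIMessage, or a list of them when n_samples > 1
```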
Korbit comment:
Inconsistent temperature default values
What is the issue?
The default temperature of 0.1 is hardcoded in multiple places: in the getattr() call (0.1) and in the HuggingFaceURLChatModel initializer (1e-1).
Why this matters
Having the same magic number in different formats (0.1 vs 1e-1) in multiple places makes it harder to maintain consistent temperature defaults and increases the risk of inconsistency when updating values.
Suggested change
Define a class-level constant DEFAULT_TEMPERATURE = 0.1 and use it consistently throughout the code, as sketched below.
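A minimal sketch of what Korbit's suggestion could look like; the constant name comes from the comment itself, and the classes are abbreviated from the diff above:

```python
from typing import Optional


class HFBaseChatModel:
    # Single source of truth for the fallback temperature.
    DEFAULT_TEMPERATURE = 0.1

    def _resolve_temperature(self, temperature: Optional[float]) -> float:
        # Mirrors the getattr(...) fallback from the diff, without magic numbers.
        return (
            temperature
            if temperature is not None
            else getattr(self, "temperature", self.DEFAULT_TEMPERATURE)
        )


class HuggingFaceURLChatModel(HFBaseChatModel):
    def __init__(self, temperature: Optional[float] = HFBaseChatModel.DEFAULT_TEMPERATURE):
        self.temperature = temperature
```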