diff --git a/requirements.txt b/requirements.txt
index 56e086e2..6433d653 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,7 @@ google-generativeai>=0.8.5
 # LLM providers
 openai>=1.0.0
 anthropic>=0.69.0
+litellm>=1.60.0,<2.0.0
 
 # Async utilities
 aiohttp>=3.8.0
diff --git a/src/model/litellm/__init__.py b/src/model/litellm/__init__.py
new file mode 100644
index 00000000..c9077ce6
--- /dev/null
+++ b/src/model/litellm/__init__.py
@@ -0,0 +1,3 @@
+from .chat import ChatLiteLLM
+
+__all__ = ["ChatLiteLLM"]
diff --git a/src/model/litellm/chat.py b/src/model/litellm/chat.py
new file mode 100644
index 00000000..afa2147d
--- /dev/null
+++ b/src/model/litellm/chat.py
@@ -0,0 +1,237 @@
+"""
+LiteLLM chat provider — route to 100+ LLM providers via a unified interface.
+
+Model strings use the ``provider/model`` format, e.g.
+``anthropic/claude-sonnet-4-20250514``, ``azure/gpt-4o``, ``bedrock/anthropic.claude-3-haiku``,
+``openai/gpt-4o``, ``ollama/llama3``, etc.
+See https://docs.litellm.ai/docs/providers for the full list.
+"""
+
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
+
+from pydantic import BaseModel, ConfigDict
+
+from src.message.types import Message
+from src.model.openai.serializer import OpenAIChatSerializer
+from src.model.types import LLMExtra, LLMResponse
+from src.logger import logger
+
+if TYPE_CHECKING:
+    # Tool is only referenced in type annotations, so this guarded import
+    # avoids a hard runtime dependency on the tool package.
+    from src.tool.types import Tool
+
+
+class ChatLiteLLM(BaseModel):
+    """
+    A wrapper around the LiteLLM SDK that provides a unified interface
+    for calling 100+ LLM providers (OpenAI, Anthropic, Google, Azure,
+    Bedrock, Ollama, etc.) through a single API.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
+
+    model: str
+    temperature: Optional[float] = 0.7
+    max_completion_tokens: Optional[int] = 16384
+    api_key: Optional[str] = None
+    api_base: Optional[str] = None
+    max_retries: int = 2
+    reasoning: Optional[Dict[str, Any]] = None
+
+    @property
+    def provider(self) -> str:
+        return "litellm"
+
+    @property
+    def name(self) -> str:
+        return str(self.model)
+
+    async def __call__(
+        self,
+        messages: List[Message],
+        tools: Optional[List["Tool"]] = None,
+        response_format: Optional[Union[Type[BaseModel], BaseModel, Dict]] = None,
+        stream: bool = False,
+        **kwargs: Any,
+    ) -> LLMResponse:
+        try:
+            import litellm
+        except ImportError:
+            raise ImportError(
+                "litellm is required for ChatLiteLLM. "
+                "Install it with: pip install litellm"
+            )
+
+        try:
+            openai_messages = OpenAIChatSerializer.serialize_messages(messages)
+
+            params: Dict[str, Any] = {
+                "model": self.model,
+                "messages": openai_messages,
+                "drop_params": True,
+            }
+
+            if self.temperature is not None:
+                params["temperature"] = self.temperature
+            if self.max_completion_tokens is not None:
+                params["max_completion_tokens"] = self.max_completion_tokens
+            if self.api_key:
+                params["api_key"] = self.api_key
+            if self.api_base:
+                params["api_base"] = self.api_base
+            if self.reasoning:
+                params.update(self.reasoning)
+
+            if tools:
+                formatted_tools = OpenAIChatSerializer.serialize_tools(tools)
+                if formatted_tools:
+                    params["tools"] = formatted_tools
+
+            if response_format:
+                # Pydantic model classes and instances share one serializer path.
+                if isinstance(response_format, BaseModel) or (
+                    isinstance(response_format, type)
+                    and issubclass(response_format, BaseModel)
+                ):
+                    params["response_format"] = (
+                        OpenAIChatSerializer.serialize_response_format(response_format)
+                    )
+                elif isinstance(response_format, dict):
+                    params["response_format"] = response_format
+
+            if stream:
+                params["stream"] = True
+
+            params.update(kwargs)
+
+            response = await litellm.acompletion(**params)
+            return await self._format_response(response, tools, response_format)
+
+        except Exception as e:
+            logger.error(f"LiteLLM error: {e}")
+            return LLMResponse(
+                success=False,
+                message=f"LiteLLM error: {str(e)}",
+                extra=LLMExtra(data={"error": str(e), "model": self.name}),
+            )
+
+    async def _format_response(
+        self,
+        response: Any,
+        tools: Optional[List["Tool"]] = None,
+        response_format: Optional[Union[Type[BaseModel], BaseModel, Dict]] = None,
+    ) -> LLMResponse:
+        try:
+            resp_dict = (
+                response.model_dump()
+                if hasattr(response, "model_dump")
+                else dict(response)
+            )
+
+            if not resp_dict.get("choices"):
+                return LLMResponse(
+                    success=False,
+                    message="No choices in response",
+                    extra=LLMExtra(data={"raw_response": resp_dict}),
+                )
+
+            choice = resp_dict["choices"][0]
+            message = choice.get("message", {})
+            usage = resp_dict.get("usage")
+            finish_reason = choice.get("finish_reason")
+
+            if tools and message.get("tool_calls"):
+                formatted_lines = []
+                functions = []
+                for tool_call in message["tool_calls"]:
+                    fn = tool_call.get("function", {})
+                    name = fn.get("name", "")
+                    try:
+                        arguments = json.loads(fn.get("arguments", "{}"))
+                    except json.JSONDecodeError:
+                        arguments = {}
+                    if arguments:
+                        args_str = ", ".join(f"{k}={v!r}" for k, v in arguments.items())
+                        formatted_lines.append(f"Calling function {name}({args_str})")
+                    else:
+                        formatted_lines.append(f"Calling function {name}()")
+                    functions.append({"name": name, "args": arguments})
+
+                return LLMResponse(
+                    success=True,
+                    message="\n".join(formatted_lines),
+                    extra=LLMExtra(
+                        data={
+                            "raw_response": resp_dict,
+                            "functions": functions,
+                            "usage": usage,
+                            "finish_reason": finish_reason,
+                        }
+                    ),
+                )
+
+            elif (
+                response_format
+                and isinstance(response_format, type)
+                and issubclass(response_format, BaseModel)
+            ):
+                content = message.get("content", "")
+                if not content:
+                    return LLMResponse(
+                        success=False,
+                        message="Empty response content from model",
+                        extra=LLMExtra(data={"raw_response": resp_dict}),
+                    )
+                try:
+                    data = json.loads(content)
+                    parsed_model = response_format.model_validate(data)
+                    model_dict = parsed_model.model_dump()
+                    field_lines = [f"{k}={v!r}" for k, v in model_dict.items()]
+                    formatted_message = (
+                        f"Response result:\n\n{response_format.__name__}(\n"
+                        + ",\n".join(f"    {line}" for line in field_lines)
+                        + "\n)"
+                    )
+                    return LLMResponse(
+                        success=True,
+                        message=formatted_message,
+                        extra=LLMExtra(
+                            parsed_model=parsed_model,
+                            data={
+                                "raw_response": resp_dict,
+                                "usage": usage,
+                                "finish_reason": finish_reason,
+                            },
+                        ),
+                    )
+                except Exception as e:
+                    # json.JSONDecodeError and pydantic ValidationError both land here.
+                    return LLMResponse(
+                        success=False,
+                        message=f"Failed to parse structured response: {e}",
+                        extra=LLMExtra(data={"error": str(e), "content": content}),
+                    )
+
+            else:
+                content = message.get("content", "")
+                return LLMResponse(
+                    success=True,
+                    message=content,
+                    extra=LLMExtra(
+                        data={
+                            "raw_response": resp_dict,
+                            "usage": usage,
+                            "finish_reason": finish_reason,
+                        }
+                    ),
+                )
+
+        except Exception as e:
+            logger.error(f"Failed to format LiteLLM response: {e}")
+            return LLMResponse(
+                success=False,
+                message=f"Failed to format response: {e}",
+                extra=LLMExtra(data={"error": str(e)}),
+            )
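For reviewers, a quick usage sketch of the new provider (not part of the diff): it assumes the repo's `Message` type and an `OPENAI_API_KEY` in the environment, and the model string is only an example.

```python
import asyncio

from src.message.types import Message
from src.model.litellm.chat import ChatLiteLLM


async def main() -> None:
    # Any LiteLLM "provider/model" string works; openai/gpt-4o is an example.
    client = ChatLiteLLM(model="openai/gpt-4o", temperature=0.2)
    response = await client(messages=[Message(role="user", content="Hello!")])
    # LLMResponse carries success, a formatted message, and raw extras.
    print(response.success, response.message)


asyncio.run(main())
```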
diff --git a/src/model/manager.py b/src/model/manager.py
index c99ad4a1..e9337e8f 100644
--- a/src/model/manager.py
+++ b/src/model/manager.py
@@ -18,6 +18,7 @@
 from src.model.openrouter.chat import ChatOpenRouter
 from src.model.anthropic.chat import ChatAnthropic
 from src.model.google.chat import ChatGoogle
+from src.model.litellm.chat import ChatLiteLLM
 from src.message.types import Message
 from src.logger import logger
 
@@ -37,7 +38,7 @@ class ModelManager:
     def __init__(self):
         """Initialize the manager."""
         self.models: Dict[str, ModelConfig] = {}
-        self.model_clients: Dict[str, Union[ChatOpenAI, ResponseOpenAI, TranscribeOpenAI, EmbeddingOpenAI, ChatOpenRouter, ChatAnthropic]] = {}
+        self.model_clients: Dict[str, Union[ChatOpenAI, ResponseOpenAI, TranscribeOpenAI, EmbeddingOpenAI, ChatOpenRouter, ChatAnthropic, ChatLiteLLM]] = {}
 
         # Default parameters
         self.max_tokens: int = 16384
@@ -66,6 +67,7 @@
         await self._initialize_openrouter_models()
         await self._initialize_anthropic_models()
         await self._initialize_google_models()
+        await self._initialize_litellm_models()
         logger.info(f"| Model manager initialized successfully with {len(self.models)} models.")
 
     async def _initialize_openai_models(self):
@@ -783,9 +785,43 @@ async def _initialize_google_models(self):
             self.models[config.model_name] = config
             await self._create_client(config)
 
+    async def _initialize_litellm_models(self):
+        """Initialize LiteLLM models — only if LITELLM_API_KEY or LITELLM_MODEL is set."""
+        if not os.getenv("LITELLM_API_KEY") and not os.getenv("LITELLM_MODEL"):
+            return
+
+        litellm_model = os.getenv("LITELLM_MODEL", "openai/gpt-4o")
+        config = ModelConfig(
+            model_name=f"litellm/{litellm_model}",
+            model_id=litellm_model,
+            model_type="chat/completions",
+            provider="litellm",
+            api_base=os.getenv("LITELLM_API_BASE"),
+            api_key=os.getenv("LITELLM_API_KEY"),
+            temperature=self.default_temperature,
+            max_completion_tokens=self.max_tokens,
+            supports_streaming=True,
+            supports_functions=True,
+            supports_vision=True,
+            output_version=None,
+            fallback_model=None,
+        )
+        self.models[config.model_name] = config
+        await self._create_client(config)
+
     async def _create_client(self, config: ModelConfig) -> None:
         """Create and cache a client for the given model config."""
-        if config.provider == "openrouter":
+        if config.provider == "litellm":
+            client = ChatLiteLLM(
+                model=config.model_id,
+                api_key=config.api_key,
+                api_base=config.api_base,
+                reasoning=config.reasoning if config.reasoning else None,
+                temperature=config.temperature or self.default_temperature,
+                max_completion_tokens=config.max_completion_tokens or self.max_tokens,
+            )
+            logger.info(f"| Created ChatLiteLLM client for {config.model_name}")
+        elif config.provider == "openrouter":
             # OpenRouter models (only chat/completions supported for now)
             if config.model_type == "chat/completions":
                 client = ChatOpenRouter(
@@ -871,7 +907,7 @@ async def _create_client(self, config: ModelConfig) -> None:
 
     async def register_model(self, config: ModelConfig) -> None:
         """Register a new model configuration."""
-        if config.provider not in ["openai", "openrouter", "anthropic", "google"]:
-            raise ValueError(f"Only OpenAI, OpenRouter, Anthropic, and Google models are supported. Got provider: {config.provider}")
+        if config.provider not in ["openai", "openrouter", "anthropic", "google", "litellm"]:
+            raise ValueError(f"Only OpenAI, OpenRouter, Anthropic, Google, and LiteLLM models are supported. Got provider: {config.provider}")
 
         self.models[config.model_name] = config
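A sketch of how the env-gated initialization is meant to be driven (the env var names come from the diff above; pointing `LITELLM_API_BASE` at a local Ollama server is an assumption for illustration):

```python
import asyncio
import os

from src.model.manager import ModelManager


async def main() -> None:
    # LITELLM_MODEL alone is enough to trigger _initialize_litellm_models();
    # LITELLM_API_KEY and LITELLM_API_BASE are optional (e.g. for a proxy).
    os.environ["LITELLM_MODEL"] = "ollama/llama3"
    os.environ["LITELLM_API_BASE"] = "http://localhost:11434"

    manager = ModelManager()
    await manager.initialize()
    # The model registers under the "litellm/" prefix.
    print("litellm/ollama/llama3" in manager.models)


asyncio.run(main())
```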
diff --git a/tests/test_litellm.py b/tests/test_litellm.py
new file mode 100644
index 00000000..12ef2927
--- /dev/null
+++ b/tests/test_litellm.py
@@ -0,0 +1,209 @@
+"""
+Tests for LiteLLM provider integration.
+
+Tests cover:
+- ChatLiteLLM instantiation and properties
+- Manager registration and client creation
+- Async completion call with mocked litellm
+- drop_params=True default
+"""
+
+import sys
+import types
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+from pathlib import Path
+
+root = str(Path(__file__).resolve().parents[1])
+sys.path.append(root)
+
+from src.model.litellm.chat import ChatLiteLLM
+from src.model.types import ModelConfig
+
+
+class TestChatLiteLLMAttributes:
+    """Tests for ChatLiteLLM class attributes and properties."""
+
+    def test_default_temperature(self):
+        client = ChatLiteLLM(model="openai/gpt-4o")
+        assert client.temperature == 0.7
+
+    def test_default_max_completion_tokens(self):
+        client = ChatLiteLLM(model="openai/gpt-4o")
+        assert client.max_completion_tokens == 16384
+
+    def test_provider_property(self):
+        client = ChatLiteLLM(model="openai/gpt-4o")
+        assert client.provider == "litellm"
+
+    def test_name_property(self):
+        client = ChatLiteLLM(model="anthropic/claude-sonnet-4-20250514")
+        assert client.name == "anthropic/claude-sonnet-4-20250514"
+
+    def test_custom_temperature(self):
+        client = ChatLiteLLM(model="openai/gpt-4o", temperature=0.2)
+        assert client.temperature == 0.2
+
+    def test_api_key_stored(self):
+        client = ChatLiteLLM(model="openai/gpt-4o", api_key="sk-test")
+        assert client.api_key == "sk-test"
+
+    def test_api_base_stored(self):
+        client = ChatLiteLLM(
+            model="openai/gpt-4o", api_base="http://localhost:4000"
+        )
+        assert client.api_base == "http://localhost:4000"
+
+
+class TestChatLiteLLMCall:
+    """Tests for ChatLiteLLM.__call__() with mocked litellm."""
+
+    @pytest.mark.asyncio
+    async def test_call_dispatches_to_litellm_acompletion(self):
+        fake_litellm = types.ModuleType("litellm")
+        mock_response = MagicMock()
+        mock_response.model_dump.return_value = {
+            "choices": [
+                {
+                    "message": {"role": "assistant", "content": "Hello!"},
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {
+                "prompt_tokens": 10,
+                "completion_tokens": 5,
+                "total_tokens": 15,
+            },
+        }
+        fake_litellm.acompletion = AsyncMock(return_value=mock_response)
+        sys.modules["litellm"] = fake_litellm
+
+        try:
+            client = ChatLiteLLM(model="openai/gpt-4o", api_key="sk-test")
+
+            from src.message.types import Message
+
+            messages = [Message(role="user", content="Hi")]
+            result = await client(messages=messages)
+
+            assert result.success is True
+            assert "Hello!" in result.message
+
+            call_kwargs = fake_litellm.acompletion.call_args
+            assert call_kwargs.kwargs["model"] == "openai/gpt-4o"
+            assert call_kwargs.kwargs["drop_params"] is True
+            assert call_kwargs.kwargs["api_key"] == "sk-test"
+        finally:
+            del sys.modules["litellm"]
+
+    @pytest.mark.asyncio
+    async def test_call_includes_drop_params_true(self):
+        fake_litellm = types.ModuleType("litellm")
+        mock_response = MagicMock()
+        mock_response.model_dump.return_value = {
+            "choices": [
+                {
+                    "message": {"role": "assistant", "content": "OK"},
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6},
+        }
+        fake_litellm.acompletion = AsyncMock(return_value=mock_response)
+        sys.modules["litellm"] = fake_litellm
+
+        try:
+            client = ChatLiteLLM(model="anthropic/claude-haiku-4-5")
+            from src.message.types import Message
+
+            messages = [Message(role="user", content="Say OK")]
+            await client(messages=messages)
+
+            call_kwargs = fake_litellm.acompletion.call_args.kwargs
+            assert call_kwargs["drop_params"] is True
+        finally:
+            del sys.modules["litellm"]
+
+    @pytest.mark.asyncio
+    async def test_call_handles_tool_calls(self):
+        fake_litellm = types.ModuleType("litellm")
+        mock_response = MagicMock()
+        mock_response.model_dump.return_value = {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "function": {
+                                    "name": "get_weather",
+                                    "arguments": '{"city": "London"}',
+                                },
+                            }
+                        ],
+                    },
+                    "finish_reason": "tool_calls",
+                }
+            ],
+            "usage": {"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
+        }
+        fake_litellm.acompletion = AsyncMock(return_value=mock_response)
+        sys.modules["litellm"] = fake_litellm
+
+        try:
+            client = ChatLiteLLM(model="openai/gpt-4o")
+            from src.message.types import Message
+
+            mock_tool = MagicMock()
+            messages = [Message(role="user", content="Weather in London?")]
+            result = await client(messages=messages, tools=[mock_tool])
+
+            assert result.success is True
+            assert "get_weather" in result.message
+            assert result.extra.data["functions"][0]["name"] == "get_weather"
+        finally:
+            del sys.modules["litellm"]
+
+    @pytest.mark.asyncio
+    async def test_call_raises_import_error_when_litellm_missing(self):
+        # A None entry in sys.modules makes ``import litellm`` raise
+        # ImportError even if the real package is installed.
+        saved = sys.modules.pop("litellm", None)
+        sys.modules["litellm"] = None
+
+        client = ChatLiteLLM(model="openai/gpt-4o")
+        from src.message.types import Message
+
+        messages = [Message(role="user", content="Hi")]
+
+        try:
+            with pytest.raises(ImportError, match="litellm is required"):
+                await client(messages=messages)
+        finally:
+            del sys.modules["litellm"]
+            if saved is not None:
+                sys.modules["litellm"] = saved
+
+
+class TestManagerLiteLLMRegistration:
+    """Tests for litellm provider registration in ModelManager."""
+
+    @pytest.mark.asyncio
+    async def test_litellm_in_allowed_providers(self):
+        from src.model.manager import ModelManager
+
+        manager = ModelManager()
+        config = ModelConfig(
+            model_name="litellm/test",
+            model_id="openai/gpt-4o",
+            model_type="chat/completions",
+            provider="litellm",
+        )
+        try:
+            # register_model validates the provider name and must not
+            # reject "litellm" now that it is an allowed provider.
+            await manager.register_model(config)
+        except ValueError as e:
+            if "models are supported" in str(e):
+                pytest.fail("litellm should be an allowed provider")
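One possible follow-up: the mocked-SDK tests repeat the same install/teardown dance around `sys.modules`. A fixture along these lines (a sketch, not part of the diff) would centralize it:

```python
import sys
import types
from unittest.mock import AsyncMock

import pytest


@pytest.fixture
def fake_litellm():
    """Install a stub litellm module for the duration of one test."""
    module = types.ModuleType("litellm")
    module.acompletion = AsyncMock()
    saved = sys.modules.get("litellm")
    sys.modules["litellm"] = module
    yield module
    # Restore the previous state so tests stay order-independent.
    if saved is None:
        del sys.modules["litellm"]
    else:
        sys.modules["litellm"] = saved
```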