diff --git a/clarifai/runners/models/dummy_openai_model.py b/clarifai/runners/models/dummy_openai_model.py
index 1d966f68..e6a43ebf 100644
--- a/clarifai/runners/models/dummy_openai_model.py
+++ b/clarifai/runners/models/dummy_openai_model.py
@@ -25,10 +25,50 @@ def create(self, **kwargs):
             else:
                 return MockResponse(**kwargs)
 
+    class Models:
+        class Model:
+            def __init__(self, model_id):
+                self.id = model_id
+
+        def list(self, **kwargs):
+            """Mock list method for models."""
+            class ModelList:
+                def __init__(self):
+                    self.data = [MockOpenAIClient.Models.Model("dummy-model")]
+            return ModelList()
+
+    class Images:
+        def generate(self, **kwargs):
+            """Mock generate method for images."""
+            # Return a simple mock image response
+            class ImageResponse:
+                def __init__(self):
+                    self.data = [{"url": "https://example.com/image.png"}]
+
+                def model_dump_json(self):
+                    return json.dumps({"data": self.data})
+            return ImageResponse()
+
+    class Embeddings:
+        def create(self, **kwargs):
+            """Mock create method for embeddings."""
+            # Return a simple mock embedding response
+            class EmbeddingResponse:
+                def __init__(self):
+                    self.data = [{"embedding": [0.1, 0.2, 0.3]}]
+                    self.usage = {"prompt_tokens": 10, "total_tokens": 10}
+
+                def model_dump_json(self):
+                    return json.dumps({"data": self.data, "usage": self.usage})
+            return EmbeddingResponse()
+
     def __init__(self):
         self.chat = self  # Make self.chat point to self for compatibility
         self.completions = self.Completions()  # For compatibility with some clients
         self.responses = self.Responses()  # For responses API
+        self.models = self.Models()  # For models.list() compatibility
+        self.images = self.Images()  # For images.generate() compatibility
+        self.embeddings = self.Embeddings()  # For embeddings.create() compatibility
 
 
 class MockCompletion:
diff --git a/clarifai/runners/models/openai_class.py b/clarifai/runners/models/openai_class.py
index 22a11427..3645ede8 100644
--- a/clarifai/runners/models/openai_class.py
+++ b/clarifai/runners/models/openai_class.py
@@ -2,8 +2,10 @@
 
 from typing import Any, Dict, Iterator
 
+import httpx
 from clarifai_grpc.grpc.api.status import status_code_pb2
 from pydantic_core import from_json, to_json
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 
 from clarifai.runners.models.model_class import ModelClass
 from clarifai.utils.logging import logger
@@ -42,12 +44,57 @@ def __init__(self) -> None:
             raise NotImplementedError("Subclasses must set the 'client' class attribute")
         if self.model is None:
             try:
-                self.model = self.client.models.list().data[0].id
+                self.model = self._retry_models_list().data[0].id
             except Exception as e:
                 raise NotImplementedError(
                     "Subclasses must set the 'model' class attribute or ensure the client can list models"
                 ) from e
 
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(httpx.ConnectError),
+    )
+    def _retry_models_list(self, **kwargs):
+        """List models with retry logic."""
+        return self.client.models.list(**kwargs)
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(httpx.ConnectError),
+    )
+    def _retry_chat_completions_create(self, **kwargs):
+        """Create chat completions with retry logic."""
+        return self.client.chat.completions.create(**kwargs)
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(httpx.ConnectError),
+    )
+    def _retry_images_generate(self, **kwargs):
+        """Generate images with retry logic."""
+        return self.client.images.generate(**kwargs)
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(httpx.ConnectError),
+    )
+    def _retry_embeddings_create(self, **kwargs):
+        """Create embeddings with retry logic."""
+        return self.client.embeddings.create(**kwargs)
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(httpx.ConnectError),
+    )
+    def _retry_responses_create(self, **kwargs):
+        """Create responses with retry logic."""
+        return self.client.responses.create(**kwargs)
+
     def _create_completion_args(self, params: Dict[str, Any]) -> Dict[str, Any]:
         """Create the completion arguments dictionary from parameters.
 
@@ -72,7 +119,7 @@ def _create_completion_args(self, params: Dict[str, Any]) -> Dict[str, Any]:
     def handle_liveness_probe(self) -> bool:
         """Handle liveness probe by checking if the client can list models."""
         try:
-            _ = self.client.models.list()
+            _ = self._retry_models_list()
             return True
         except Exception as e:
             logger.error(f"Liveness probe failed: {e}", exc_info=True)
@@ -81,7 +128,7 @@ def handle_liveness_probe(self) -> bool:
     def handle_readiness_probe(self) -> bool:
         """Handle readiness probe by checking if the client can list models."""
         try:
-            _ = self.client.models.list()
+            _ = self._retry_models_list()
             return True
         except Exception as e:
             logger.error(f"Readiness probe failed: {e}", exc_info=True)
@@ -135,7 +182,7 @@ def _set_usage(self, resp):
     def _handle_chat_completions(self, request_data: Dict[str, Any]):
         """Handle chat completion requests."""
         completion_args = self._create_completion_args(request_data)
-        completion = self.client.chat.completions.create(**completion_args)
+        completion = self._retry_chat_completions_create(**completion_args)
         self._set_usage(completion)
         return completion
 
@@ -143,21 +190,21 @@ def _handle_images_generate(self, request_data: Dict[str, Any]):
         """Handle image generation requests."""
         image_args = {**request_data}
         image_args.update({"model": self.model})
-        response = self.client.images.generate(**image_args)
+        response = self._retry_images_generate(**image_args)
         return response
 
     def _handle_embeddings(self, request_data: Dict[str, Any]):
         """Handle embedding requests."""
         embedding_args = {**request_data}
         embedding_args.update({"model": self.model})
-        response = self.client.embeddings.create(**embedding_args)
+        response = self._retry_embeddings_create(**embedding_args)
         return response
 
     def _handle_responses(self, request_data: Dict[str, Any]):
         """Handle response requests."""
         response_args = {**request_data}
         response_args.update({"model": self.model})
-        response = self.client.responses.create(**response_args)
+        response = self._retry_responses_create(**response_args)
         self._set_usage(response)
         return response
 
@@ -257,7 +304,7 @@ def openai_stream_transport(self, msg: str) -> Iterator[str]:
                     yield chunk.model_dump_json()
             else:
                 completion_args = self._create_completion_args(request_data)
-                stream_completion = self.client.chat.completions.create(**completion_args)
+                stream_completion = self._retry_chat_completions_create(**completion_args)
                 for chunk in stream_completion:
                     self._set_usage(chunk)
                     yield chunk.model_dump_json()
diff --git a/requirements.txt b/requirements.txt
index e80d8279..876f7990 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ psutil==7.0.0
 pygments>=2.19.2
 pydantic_core==2.33.2
 packaging==25.0
+tenacity>=8.2.3
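
The MockOpenAIClient additions can be smoke-tested directly. A minimal sketch, assuming MockOpenAIClient is importable from the touched module; the expected values come straight from the stubs in this diff:

import json  # noqa: F401  (the stubs serialize via json.dumps inside the module)
from clarifai.runners.models.dummy_openai_model import MockOpenAIClient

client = MockOpenAIClient()

# models.list() returns a single stub model whose id is "dummy-model"
assert client.models.list().data[0].id == "dummy-model"

# images.generate() and embeddings.create() accept arbitrary kwargs and return
# objects exposing model_dump_json(), mirroring the real SDK response shape
print(client.images.generate(prompt="a red square").model_dump_json())
print(client.embeddings.create(input="hello").model_dump_json())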
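
All five _retry_* helpers in openai_class.py share one tenacity policy: at most 3 attempts, exponential backoff clamped between 4 and 10 seconds, and retrying only on httpx.ConnectError so that non-transient failures (bad requests, auth errors) still surface immediately. A standalone sketch of that policy follows; fetch_first_model_id and flaky_client are illustrative names, not part of this diff:

import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

@retry(
    stop=stop_after_attempt(3),                          # give up after the third try
    wait=wait_exponential(multiplier=1, min=4, max=10),  # back off 4s, growing up to 10s
    retry=retry_if_exception_type(httpx.ConnectError),   # only connection errors retry
)
def fetch_first_model_id(flaky_client):
    """Stand-in for _retry_models_list(), using the same decorator as the diff."""
    return flaky_client.models.list().data[0].id

One behavioral note: without reraise=True, tenacity raises its own RetryError (wrapping the final httpx.ConnectError) once attempts are exhausted rather than re-raising the original exception; the broad except Exception handlers in __init__ and the probe methods already catch that.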