diff --git a/.env_integration_tests.example b/.env_integration_tests.example index b73fe86..d51c852 100644 --- a/.env_integration_tests.example +++ b/.env_integration_tests.example @@ -16,3 +16,8 @@ CLOUD_SDK_CFG_DESTINATION_DEFAULT_IDENTITYZONE=your-identity-zone-here CLOUD_SDK_CFG_SDM_DEFAULT_URI=https://your-sdm-api-uri-here CLOUD_SDK_CFG_SDM_DEFAULT_UAA='{"url":"https://your-auth-url","clientid":"your-client-id","clientsecret":"your-client-secret","identityzone":"your-identity-zone"}' + +CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_APPLICATION_URL=https://your-agent-memory-api-url-here +CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_URL=https://your-auth-url-here +CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTID=your-client-id-here +CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTSECRET=your-client-secret-here diff --git a/docs/INTEGRATION_TESTS.md b/docs/INTEGRATION_TESTS.md index 28a9b9f..34f6b3c 100644 --- a/docs/INTEGRATION_TESTS.md +++ b/docs/INTEGRATION_TESTS.md @@ -9,10 +9,12 @@ Integration tests verify that the SDK modules work correctly with real external ## Prerequisites ### Required Tools + - **Python 3.11+**: Required for running the tests - **uv**: Package manager for dependency management ### Install Dependencies + ```bash # Install all dependencies including test dependencies uv sync --all-extras @@ -62,6 +64,18 @@ CLOUD_SDK_CFG_DESTINATION_DEFAULT_URI=https://your-destination-configuration-uri CLOUD_SDK_CFG_DESTINATION_DEFAULT_IDENTITYZONE=your-identity-zone-here ``` +### Agent Memory Integration Tests + +For Agent Memory integration tests, configure the following variables in `.env_integration_tests`: + +```bash +# Agent Memory Configuration +CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_URL=https://your-agent-memory-api-url +CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_AUTH_URL=https://your-auth-url +CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_CLIENTID=your-client-id +CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_CLIENTSECRET=your-client-secret +``` + ## Running Integration Tests ```bash @@ -72,6 +86,7 @@ uv run pytest tests/ -m integration -v uv run pytest tests/core/integration/auditlog -v uv run pytest tests/objectstore/integration/ -v uv run pytest tests/destination/integration/ -v +uv run pytest tests/agent_memory/integration/ -v ``` ### BDD Scenarios diff --git a/src/sap_cloud_sdk/agent_memory/__init__.py b/src/sap_cloud_sdk/agent_memory/__init__.py new file mode 100644 index 0000000..98195c8 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/__init__.py @@ -0,0 +1,78 @@ +"""SAP Cloud SDK for Python — Agent Memory module. + +The ``create_client()`` function auto-detects credentials from a mounted volume +or ``CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_*`` environment variables. + +Usage:: + + from sap_cloud_sdk.agent_memory import create_client + + client = create_client() + memories = client.list_memories(agent_id="my-agent", invoker_id="user-123") +""" + +from typing import Optional + +from sap_cloud_sdk.agent_memory._http_transport import HttpTransport +from sap_cloud_sdk.agent_memory.client import AgentMemoryClient +from sap_cloud_sdk.agent_memory.config import AgentMemoryConfig, _load_config_from_env +from sap_cloud_sdk.agent_memory.exceptions import ( + AgentMemoryConfigError, + AgentMemoryError, + AgentMemoryHttpError, + AgentMemoryNotFoundError, + AgentMemoryValidationError, +) +from sap_cloud_sdk.agent_memory._models import ( + Memory, + Message, + MessageRole, + RetentionConfig, + SearchResult, +) +from sap_cloud_sdk.agent_memory.utils._odata import FilterDefinition + + +def create_client(*, config: Optional[AgentMemoryConfig] = None) -> AgentMemoryClient: + """Create an :class:`AgentMemoryClient` with automatic credential detection. + + Args: + config: Optional explicit configuration. If ``None``, credentials are + loaded from the mounted volume at + ``/etc/secrets/appfnd/hana-agent-memory/default/`` or from + ``CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_*`` environment variables. + + Returns: + A ready-to-use :class:`AgentMemoryClient`. + + Raises: + AgentMemoryConfigError: If configuration is missing or invalid. + """ + try: + resolved_config = config if config is not None else _load_config_from_env() + transport = HttpTransport(resolved_config) + return AgentMemoryClient(transport) + except AgentMemoryConfigError: + raise + except Exception as exc: + raise AgentMemoryConfigError( + f"Failed to create Agent Memory client: {exc}" + ) from exc + + +__all__ = [ + "AgentMemoryClient", + "AgentMemoryConfig", + "AgentMemoryError", + "AgentMemoryConfigError", + "AgentMemoryHttpError", + "AgentMemoryNotFoundError", + "AgentMemoryValidationError", + "FilterDefinition", + "Memory", + "Message", + "MessageRole", + "RetentionConfig", + "SearchResult", + "create_client", +] diff --git a/src/sap_cloud_sdk/agent_memory/_endpoints.py b/src/sap_cloud_sdk/agent_memory/_endpoints.py new file mode 100644 index 0000000..a241900 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/_endpoints.py @@ -0,0 +1,44 @@ +"""Agent Memory API endpoint path constants. + +All endpoint paths are centralised here so that migrating to a new API version +requires changes in only this one file. + +Current API version: v1 + - Memories CRUD + search: /v1/memories + - Messages CRUD: /v1/messages + - Admin (retention): /v1/admin/retentionConfig +""" + +from __future__ import annotations + +# ── Base path ────────────────────────────────────────────────────────────────── + +BASE_PATH = "/v1" + +# ── Memory endpoints ────────────────────────────────────────────────────────── + +MEMORIES = f"{BASE_PATH}/memories" +# POST MEMORIES → create memory +# GET MEMORIES → list memories (with OData $filter / $top / $skip) +# GET MEMORIES({id}) → get memory +# PATCH MEMORIES({id}) → update memory +# DELETE MEMORIES({id}) → delete memory + +MEMORY_SEARCH = f"{MEMORIES}/search" +# POST MEMORY_SEARCH → semantic similarity search + +# ── Message endpoints ───────────────────────────────────────────────────────── + +MESSAGES = f"{BASE_PATH}/messages" +# POST MESSAGES → create message +# GET MESSAGES → list messages (with OData $filter / $top / $skip) +# GET MESSAGES({id}) → get message +# DELETE MESSAGES({id}) → delete message (not updatable) + +# ── Admin endpoints ─────────────────────────────────────────────────────────── + +ADMIN_BASE_PATH = "/v1/admin" + +RETENTION_CONFIG = f"{ADMIN_BASE_PATH}/retentionConfig" +# GET RETENTION_CONFIG → get singleton retention config +# PATCH RETENTION_CONFIG → update retention policy diff --git a/src/sap_cloud_sdk/agent_memory/_http_transport.py b/src/sap_cloud_sdk/agent_memory/_http_transport.py new file mode 100644 index 0000000..e8c8017 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/_http_transport.py @@ -0,0 +1,221 @@ +"""HTTP transport for the Agent Memory service. + +Handles OAuth2 ``client_credentials`` token acquisition with lazy, +expiry-aware caching. If ``token_url`` is not configured, requests are +sent unauthenticated — expected for local development environments. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta +from typing import Any, Optional +from urllib.parse import quote, urlencode + +import requests +from oauthlib.oauth2 import BackendApplicationClient +from requests.exceptions import RequestException, Timeout +from requests_oauthlib import OAuth2Session + +from sap_cloud_sdk.agent_memory.config import AgentMemoryConfig +from sap_cloud_sdk.agent_memory.exceptions import ( + AgentMemoryHttpError, + AgentMemoryNotFoundError, +) + +logger = logging.getLogger(__name__) + +_TOKEN_EXPIRY_BUFFER_SECONDS = 60 + + +class HttpTransport: + """Internal HTTP transport for the Agent Memory service. + + Manages OAuth2 token lifecycle (lazy acquire + expiry-aware caching) and + attaches the ``Authorization`` header to every request automatically via + ``OAuth2Session``. In no-auth mode (no ``token_url``), a plain + ``requests.Session`` is used instead. + + Args: + config: Service configuration. + """ + + def __init__(self, config: AgentMemoryConfig) -> None: + self._config = config + self._oauth: Optional[OAuth2Session] = None + self._plain_session: Optional[requests.Session] = None + self._token_expires_at: Optional[datetime] = None + + def close(self) -> None: + """Close the underlying HTTP session(s) and release resources.""" + if self._oauth is not None: + self._oauth.close() + self._oauth = None + if self._plain_session is not None: + self._plain_session.close() + self._plain_session = None + + # ── Public HTTP methods ──────────────────────────────────────────────────── + + def get(self, path: str, params: Optional[dict[str, Any]] = None) -> dict[str, Any]: + """Perform a GET request. + + Args: + path: API path (appended to ``base_url``). + params: Optional query parameters. + + Returns: + Parsed JSON response body. + + Raises: + AgentMemoryHttpError: On HTTP errors or network failures. + AgentMemoryNotFoundError: If the server returns 404. + """ + return self._request("GET", path, params=params) + + def post(self, path: str, json: Optional[dict[str, Any]] = None) -> dict[str, Any]: + """Perform a POST request. + + Args: + path: API path (appended to ``base_url``). + json: Optional request body dict (serialised to JSON). + + Returns: + Parsed JSON response body. Returns an empty dict for 204 responses. + + Raises: + AgentMemoryHttpError: On HTTP errors or network failures. + AgentMemoryNotFoundError: If the server returns 404. + """ + return self._request("POST", path, json=json) + + def patch(self, path: str, json: Optional[dict[str, Any]] = None) -> dict[str, Any]: + """Perform a PATCH request. + + Args: + path: API path (appended to ``base_url``). + json: Optional request body dict (serialised to JSON). + + Returns: + Parsed JSON response body. Returns an empty dict for 204 responses. + + Raises: + AgentMemoryHttpError: On HTTP errors or network failures. + AgentMemoryNotFoundError: If the server returns 404. + """ + return self._request("PATCH", path, json=json) + + def delete(self, path: str) -> None: + """Perform a DELETE request. + + Args: + path: API path (appended to ``base_url``). + + Raises: + AgentMemoryHttpError: On HTTP errors or network failures. + AgentMemoryNotFoundError: If the server returns 404. + """ + self._request("DELETE", path) + + # ── Internal helpers ─────────────────────────────────────────────────────── + + def _get_session(self) -> requests.Session: + """Return a session ready to make requests. + + In no-auth mode, returns a plain ``requests.Session`` (created once). + In OAuth2 mode, returns an ``OAuth2Session`` with a valid token, + fetching or refreshing the token if needed. + """ + if not self._config.token_url: + if self._plain_session is None: + self._plain_session = requests.Session() + return self._plain_session + + if ( + self._oauth is not None + and self._token_expires_at is not None + and datetime.now() < self._token_expires_at + ): + return self._oauth + + self._oauth = self._fetch_token() + return self._oauth + + def _fetch_token(self) -> OAuth2Session: + """Acquire a new OAuth2 ``client_credentials`` token. + + Returns: + An ``OAuth2Session`` with a valid token attached. + + Raises: + AgentMemoryHttpError: If the token endpoint returns an error or is unreachable. + """ + try: + client = BackendApplicationClient(client_id=self._config.client_id) + oauth = OAuth2Session(client=client) + token = oauth.fetch_token( + token_url=self._config.token_url, + client_id=self._config.client_id, + client_secret=self._config.client_secret, + timeout=self._config.timeout, + ) + except Exception as exc: + raise AgentMemoryHttpError(f"Failed to obtain OAuth2 token: {exc}") from exc + + expires_in: int = token.get("expires_in", 3600) + self._token_expires_at = datetime.now() + timedelta( + seconds=expires_in - _TOKEN_EXPIRY_BUFFER_SECONDS + ) + + if self._oauth is not None: + self._oauth.close() + + logger.debug( + "Obtained new Agent Memory OAuth2 token (expires in %ds)", expires_in + ) + return oauth + + def _request(self, method: str, path: str, **kwargs: Any) -> dict[str, Any]: + """Execute an HTTP request using the appropriate session.""" + logger.debug("%s %s", method, path) + + url = f"{self._config.base_url}{path}" + if "params" in kwargs: + raw_params: dict[str, Any] = kwargs.pop("params") + if raw_params: + url = f"{url}?{urlencode(raw_params, quote_via=quote)}" + + session = self._get_session() + headers = {"Content-Type": "application/json"} + + try: + response = session.request( + method, url, headers=headers, timeout=self._config.timeout, **kwargs + ) + except Timeout as exc: + raise AgentMemoryHttpError(f"Request timed out: {method} {path}") from exc + except RequestException as exc: + raise AgentMemoryHttpError( + f"Request failed: {method} {path} — {exc}" + ) from exc + + if response.status_code == 204 or not response.content: + return {} + + if response.status_code == 404: + raise AgentMemoryNotFoundError( + f"Resource not found: {method} {path}", + status_code=404, + response_text=response.text, + ) + + if not response.ok: + raise AgentMemoryHttpError( + f"Agent Memory service request failed. " + f"Method: {method}, Path: {path}, " + f"Status: {response.status_code}, Response: {response.text}", + status_code=response.status_code, + response_text=response.text, + ) + + return response.json() diff --git a/src/sap_cloud_sdk/agent_memory/_models.py b/src/sap_cloud_sdk/agent_memory/_models.py new file mode 100644 index 0000000..ed41a29 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/_models.py @@ -0,0 +1,252 @@ +"""Data models for the Agent Memory service (v1 API). + +Each model exposes a ``from_dict()`` class method that maps the API response +payload to a typed Python object. + +When migrating to a new API version, only the ``from_dict()`` methods and field +definitions in this file need to be updated. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional + + +class MessageRole(str, Enum): + """Role of the message author.""" + + USER = "USER" + ASSISTANT = "ASSISTANT" + SYSTEM = "SYSTEM" + TOOL = "TOOL" + + +def _parse_metadata(raw: Any) -> Optional[dict[str, Any]]: + if raw is None: + return None + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + return json.loads(raw) + except (json.JSONDecodeError, TypeError): + return {"raw": raw} + return None + + +@dataclass +class Memory: + """Represents a memory entry with automatic vector embeddings. + + Attributes: + id: Unique memory identifier (UUID). + agent_id: Identifier of the agent that owns this memory. + invoker_id: Identifier of the user or invoker. + content: The memory text content. + metadata: Optional metadata dict (Map type in OData). + create_timestamp: ISO-8601 creation timestamp (read-only, set by server). + update_timestamp: ISO-8601 last-update timestamp (read-only, set by server). + """ + + id: str + agent_id: str + invoker_id: str + content: str + metadata: Optional[dict[str, Any]] = None + create_timestamp: Optional[str] = None + update_timestamp: Optional[str] = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Memory: + """Create a ``Memory`` from an API response dictionary.""" + return cls( + id=data.get("id", ""), + agent_id=data.get("agentID", ""), + invoker_id=data.get("invokerID", ""), + content=data.get("content", ""), + metadata=_parse_metadata(data.get("metadata")), + create_timestamp=data.get("createTimestamp"), + update_timestamp=data.get("updateTimestamp"), + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + d: dict[str, Any] = { + "id": self.id, + "agentID": self.agent_id, + "invokerID": self.invoker_id, + "content": self.content, + "createTimestamp": self.create_timestamp, + "updateTimestamp": self.update_timestamp, + } + if self.metadata is not None: + d["metadata"] = self.metadata + return d + + +@dataclass +class SearchResult: + """Represents a memory search result with similarity scores. + + Returned by the ``search_memories`` operation. + + Attributes: + id: Unique memory identifier (UUID). + agent_id: Identifier of the agent that owns this memory. + invoker_id: Identifier of the user or invoker. + content: The memory text content. + similarity: Cosine similarity score (0.0–1.0). + metadata: Optional metadata dict. + create_timestamp: ISO-8601 creation timestamp. + update_timestamp: ISO-8601 last-update timestamp. + """ + + id: str + agent_id: str + invoker_id: str + content: str + similarity: Optional[float] = None + metadata: Optional[dict[str, Any]] = None + create_timestamp: Optional[str] = None + update_timestamp: Optional[str] = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SearchResult: + """Create a ``SearchResult`` from an API response dictionary.""" + return cls( + id=data.get("id", ""), + agent_id=data.get("agentID", ""), + invoker_id=data.get("invokerID", ""), + content=data.get("content", ""), + similarity=data.get("similarity"), + metadata=_parse_metadata(data.get("metadata")), + create_timestamp=data.get("createTimestamp"), + update_timestamp=data.get("updateTimestamp"), + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + d: dict[str, Any] = { + "id": self.id, + "agentID": self.agent_id, + "invokerID": self.invoker_id, + "content": self.content, + "similarity": self.similarity, + "createTimestamp": self.create_timestamp, + "updateTimestamp": self.update_timestamp, + } + if self.metadata is not None: + d["metadata"] = self.metadata + return d + + +@dataclass +class Message: + """Represents a chat message in the Agent Memory system. + + Messages belonging to the same logical conversation are grouped + via the ``message_group`` field. + + Attributes: + id: Unique message identifier (UUID, read-only). + agent_id: Identifier of the agent. + invoker_id: Identifier of the user or invoker. + message_group: Group identifier for conversation threading. + role: Author role (USER, ASSISTANT, SYSTEM, TOOL). Nullable per API spec. + content: The message text content. + metadata: Optional metadata dict. + create_timestamp: ISO-8601 creation timestamp (read-only, set by server). + """ + + id: str + agent_id: str + invoker_id: str + message_group: str + content: str + role: Optional[MessageRole] = None + metadata: Optional[dict[str, Any]] = None + create_timestamp: Optional[str] = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Message: + """Create a ``Message`` from an API response dictionary.""" + raw_role = data.get("role") + return cls( + id=data.get("id", ""), + agent_id=data.get("agentID", ""), + invoker_id=data.get("invokerID", ""), + message_group=data.get("messageGroup", ""), + content=data.get("content", ""), + role=MessageRole(raw_role) if raw_role else None, + metadata=_parse_metadata(data.get("metadata")), + create_timestamp=data.get("createTimestamp"), + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + d: dict[str, Any] = { + "id": self.id, + "agentID": self.agent_id, + "invokerID": self.invoker_id, + "messageGroup": self.message_group, + "content": self.content, + "createTimestamp": self.create_timestamp, + } + if self.role is not None: + d["role"] = self.role + if self.metadata is not None: + d["metadata"] = self.metadata + return d + + +# ── Admin models ────────────────────────────────────────────────────────────── + + +@dataclass +class RetentionConfig: + """Represents the data retention configuration (singleton). + + Manages data retention policies across different data categories. + Set a field to ``None`` to disable automatic cleanup for that category. + + Attributes: + id: Config identifier (integer, read-only, set by server). + message_days: How long to keep messages (days). ``None`` disables cleanup. + memory_days: How long to keep memories without access (days). ``None`` disables cleanup. + usage_log_days: How long to keep access and search logs (days). ``None`` disables cleanup. + create_timestamp: ISO-8601 creation timestamp (read-only). + update_timestamp: ISO-8601 last-update timestamp (read-only). + """ + + id: Optional[int] + message_days: Optional[int] = None + memory_days: Optional[int] = None + usage_log_days: Optional[int] = None + create_timestamp: Optional[str] = None + update_timestamp: Optional[str] = None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> RetentionConfig: + """Create a ``RetentionConfig`` from an API response dictionary.""" + return cls( + id=data.get("id"), + message_days=data.get("messageDays"), + memory_days=data.get("memoryDays"), + usage_log_days=data.get("usageLogDays"), + create_timestamp=data.get("createTimestamp"), + update_timestamp=data.get("updateTimestamp"), + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "id": self.id, + "messageDays": self.message_days, + "memoryDays": self.memory_days, + "usageLogDays": self.usage_log_days, + "createTimestamp": self.create_timestamp, + "updateTimestamp": self.update_timestamp, + } diff --git a/src/sap_cloud_sdk/agent_memory/client.py b/src/sap_cloud_sdk/agent_memory/client.py new file mode 100644 index 0000000..f12030f --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/client.py @@ -0,0 +1,516 @@ +"""Client for the SAP Agent Memory service (v1 API). + +Provides memory management (CRUD + semantic search) and message operations +over a synchronous HTTP interface. All endpoint paths are defined in +``_endpoints.py``, making it straightforward to migrate to a new API version. + +Do not instantiate this class directly — use :func:`sap_cloud_sdk.agent_memory.create_client`. +""" + +from __future__ import annotations + +from typing import Any, Optional + +from sap_cloud_sdk.agent_memory._endpoints import ( + MEMORIES, + MEMORY_SEARCH, + MESSAGES, + RETENTION_CONFIG, +) +from sap_cloud_sdk.agent_memory._http_transport import HttpTransport +from sap_cloud_sdk.agent_memory._models import ( + Memory, + Message, + MessageRole, + RetentionConfig, + SearchResult, +) +from sap_cloud_sdk.agent_memory.utils._odata import ( + FilterDefinition, + build_list_params, + build_memory_filter, + build_message_filter, + extract_value_and_count, +) +from sap_cloud_sdk.agent_memory.exceptions import AgentMemoryValidationError +from sap_cloud_sdk.core.telemetry import Module, Operation, record_metrics + + +def _require_non_empty(**fields: str) -> None: + """Raise AgentMemoryValidationError if any named field is an empty string.""" + empty = [name for name, value in fields.items() if not value] + if empty: + names = ", ".join(f"'{n}'" for n in empty) + raise AgentMemoryValidationError( + f"Required field(s) must be non-empty: {names}" + ) + + +def _validate_filter_clauses( + clauses: list[FilterDefinition], allowed_targets: set[str] +) -> None: + """Raise AgentMemoryValidationError if any FilterDefinition is invalid.""" + allowed_str = ", ".join(f'"{t}"' for t in sorted(allowed_targets)) + for clause in clauses: + if clause.target not in allowed_targets: + raise AgentMemoryValidationError( + f"FilterDefinition 'target' must be one of {{{allowed_str}}}, " + f'got "{clause.target}"' + ) + if not clause.contains: + raise AgentMemoryValidationError( + "FilterDefinition 'contains' must be a non-empty string" + ) + + +class AgentMemoryClient: + """Client for the SAP Agent Memory service (v1 API). + + Provides memory CRUD, semantic search, and message management. + + Do not instantiate directly — use :func:`sap_cloud_sdk.agent_memory.create_client`. + + Args: + transport: HTTP transport layer (injected by ``create_client``). + """ + + def __init__(self, transport: HttpTransport) -> None: + self._transport = transport + + def close(self) -> None: + """Close the underlying HTTP session and release resources.""" + self._transport.close() + + def __enter__(self) -> AgentMemoryClient: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + # ── Memory operations ────────────────────────────────────────────────────── + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_ADD_MEMORY) + def add_memory( + self, + agent_id: str, + invoker_id: str, + content: str, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> Memory: + """Create a new memory entry. + + Args: + agent_id: Identifier of the agent. + invoker_id: Identifier of the user or invoker. + content: The memory text content. + metadata: Optional metadata dict (Map type in OData). + + Returns: + The created :class:`Memory`. + + Raises: + AgentMemoryValidationError: If any required field is empty. + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty(agent_id=agent_id, invoker_id=invoker_id, content=content) + payload: dict[str, Any] = { + "agentID": agent_id, + "invokerID": invoker_id, + "content": content, + } + if metadata is not None: + payload["metadata"] = metadata + data = self._transport.post(MEMORIES, json=payload) + return Memory.from_dict(data) + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_GET_MEMORY) + def get_memory(self, memory_id: str) -> Memory: + """Retrieve a memory by ID. + + Args: + memory_id: The memory identifier (UUID). + + Returns: + The :class:`Memory`. + + Raises: + AgentMemoryNotFoundError: If no memory with the given ID exists. + AgentMemoryValidationError: If ``memory_id`` is empty. + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty(memory_id=memory_id) + data = self._transport.get(f"{MEMORIES}({memory_id})") + return Memory.from_dict(data) + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_UPDATE_MEMORY) + def update_memory( + self, + memory_id: str, + *, + content: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> None: + """Update a memory's content and/or metadata. + + Args: + memory_id: The memory identifier (UUID). + content: New content to set. + metadata: New metadata dict to set. + + Raises: + AgentMemoryNotFoundError: If no memory with the given ID exists. + AgentMemoryValidationError: If ``memory_id`` is empty or no fields are provided. + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty(memory_id=memory_id) + if content is None and metadata is None: + raise AgentMemoryValidationError( + "At least one of 'content' or 'metadata' must be provided" + ) + payload: dict[str, Any] = {} + if content is not None: + payload["content"] = content + if metadata is not None: + payload["metadata"] = metadata + self._transport.patch(f"{MEMORIES}({memory_id})", json=payload) + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_DELETE_MEMORY) + def delete_memory(self, memory_id: str) -> None: + """Delete a memory permanently. + + Args: + memory_id: The memory identifier (UUID). + + Raises: + AgentMemoryNotFoundError: If no memory with the given ID exists. + AgentMemoryValidationError: If ``memory_id`` is empty. + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty(memory_id=memory_id) + self._transport.delete(f"{MEMORIES}({memory_id})") + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_LIST_MEMORIES) + def list_memories( + self, + agent_id: Optional[str] = None, + invoker_id: Optional[str] = None, + *, + filters: Optional[list[FilterDefinition]] = None, + limit: int = 50, + offset: int = 0, + ) -> list[Memory]: + """List memories, optionally filtered by agent and/or invoker. + + Args: + agent_id: Filter by agent identifier. + invoker_id: Filter by invoker/user identifier. + filters: Additional substring filters. Each :class:`FilterDefinition` + specifies a ``target`` field (``"metadata"`` or ``"content"``) + and a required ``contains`` substring. Multiple clauses are + combined with AND. Metadata filtering is free-text only — + key-value structured search is not supported. + limit: Maximum number of memories to return. Default is ``50``. + offset: Number of memories to skip (for pagination). Default is ``0``. + + Returns: + List of :class:`Memory` objects. + + Raises: + AgentMemoryValidationError: If ``limit`` < 1, ``offset`` < 0, or a + filter clause is invalid. + AgentMemoryHttpError: If the request fails. + """ + if limit < 1: + raise AgentMemoryValidationError("'limit' must be >= 1") + if offset < 0: + raise AgentMemoryValidationError("'offset' must be >= 0") + if filters is not None: + _validate_filter_clauses(filters, {"metadata", "content"}) + params = build_list_params( + filter_expr=build_memory_filter( + agent_id=agent_id, + invoker_id=invoker_id, + filter_clauses=filters, + ), + top=limit, + skip=offset if offset else None, + ) + response = self._transport.get(MEMORIES, params=params) + items, _ = extract_value_and_count(response) + return [Memory.from_dict(item) for item in items] + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_COUNT_MEMORIES) + def count_memories( + self, + agent_id: Optional[str] = None, + invoker_id: Optional[str] = None, + ) -> int: + """Count memories matching the given filters. + + Args: + agent_id: Filter by agent identifier. + invoker_id: Filter by invoker/user identifier. + + Returns: + Total number of matching memories. + + Raises: + AgentMemoryHttpError: If the request fails. + """ + params = build_list_params( + filter_expr=build_memory_filter(agent_id=agent_id, invoker_id=invoker_id), + top=0, + count=True, + ) + response = self._transport.get(MEMORIES, params=params) + _, total = extract_value_and_count(response) + return total or 0 + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_SEARCH_MEMORIES) + def search_memories( + self, + agent_id: str, + invoker_id: str, + query: str, + threshold: float = 0.6, + limit: int = 10, + ) -> list[SearchResult]: + """Perform a semantic (vector) search over stored memories. + + Args: + agent_id: Agent identifier to scope the search. + invoker_id: Invoker/user identifier to scope the search. + query: Natural-language search query (5–5000 characters). + threshold: Minimum cosine similarity score (0.0–1.0). Default ``0.6``. + limit: Maximum number of results (1–50). Default is ``10``. + + Returns: + List of :class:`SearchResult` objects. + + Raises: + AgentMemoryValidationError: If required fields are empty or parameters are + out of range (``query`` must be 5–5000 chars, ``threshold`` 0.0–1.0, + ``limit`` 1–50). + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty(agent_id=agent_id, invoker_id=invoker_id, query=query) + if not (5 <= len(query) <= 5000): + raise AgentMemoryValidationError( + "'query' must be between 5 and 5000 characters" + ) + if not (0.0 <= threshold <= 1.0): + raise AgentMemoryValidationError("'threshold' must be between 0.0 and 1.0") + if not (1 <= limit <= 50): + raise AgentMemoryValidationError("'limit' must be between 1 and 50") + payload: dict[str, Any] = { + "agentID": agent_id, + "invokerID": invoker_id, + "query": query, + "threshold": threshold, + "top": limit, + } + response = self._transport.post(MEMORY_SEARCH, json=payload) + items = response.get("value", []) + return [SearchResult.from_dict(item) for item in items] + + # ── Message operations ───────────────────────────────────────────────────── + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_ADD_MESSAGE) + def add_message( + self, + agent_id: str, + invoker_id: str, + message_group: str, + role: MessageRole, + content: str, + *, + metadata: Optional[dict[str, Any]] = None, + ) -> Message: + """Create a new message. + + One message is stored per call. Messages sharing the same + ``message_group`` form a logical conversation. + + Args: + agent_id: Identifier of the agent. + invoker_id: Identifier of the user or invoker. + message_group: Group identifier for conversation threading. + role: Author role (USER, ASSISTANT, SYSTEM, TOOL). + content: The message text content. + metadata: Optional metadata dict. + + Returns: + The created :class:`Message`. + + Raises: + AgentMemoryValidationError: If any required field is empty. + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty( + agent_id=agent_id, + invoker_id=invoker_id, + message_group=message_group, + content=content, + ) + payload: dict[str, Any] = { + "agentID": agent_id, + "invokerID": invoker_id, + "messageGroup": message_group, + "role": role, + "content": content, + } + if metadata is not None: + payload["metadata"] = metadata + data = self._transport.post(MESSAGES, json=payload) + return Message.from_dict(data) + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_GET_MESSAGE) + def get_message(self, message_id: str) -> Message: + """Retrieve a message by ID. + + Args: + message_id: The message identifier (UUID). + + Returns: + The :class:`Message`. + + Raises: + AgentMemoryNotFoundError: If no message with the given ID exists. + AgentMemoryValidationError: If ``message_id`` is empty. + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty(message_id=message_id) + data = self._transport.get(f"{MESSAGES}({message_id})") + return Message.from_dict(data) + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_DELETE_MESSAGE) + def delete_message(self, message_id: str) -> None: + """Delete a message permanently. + + Args: + message_id: The message identifier (UUID). + + Raises: + AgentMemoryNotFoundError: If no message with the given ID exists. + AgentMemoryValidationError: If ``message_id`` is empty. + AgentMemoryHttpError: If the request fails. + """ + _require_non_empty(message_id=message_id) + self._transport.delete(f"{MESSAGES}({message_id})") + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_LIST_MESSAGES) + def list_messages( + self, + agent_id: Optional[str] = None, + invoker_id: Optional[str] = None, + message_group: Optional[str] = None, + role: Optional[str] = None, + *, + filters: Optional[list[FilterDefinition]] = None, + limit: int = 50, + offset: int = 0, + ) -> list[Message]: + """List messages, optionally filtered by agent, invoker, group, and role. + + Args: + agent_id: Filter by agent identifier. + invoker_id: Filter by invoker/user identifier. + message_group: Filter by conversation group identifier. + role: Filter by author role (USER, ASSISTANT, SYSTEM, TOOL). + filters: Additional substring filters. Each :class:`FilterDefinition` + specifies a ``target`` field (``"metadata"`` or ``"content"``) + and a required ``contains`` substring. Multiple clauses are + combined with AND. Metadata filtering is free-text only — + key-value structured search is not supported. + limit: Maximum number of messages to return. Default is ``50``. + offset: Number of messages to skip (for pagination). Default is ``0``. + + Returns: + List of :class:`Message` objects. + + Raises: + AgentMemoryValidationError: If ``limit`` < 1, ``offset`` < 0, or a + filter clause is invalid. + AgentMemoryHttpError: If the request fails. + """ + if limit < 1: + raise AgentMemoryValidationError("'limit' must be >= 1") + if offset < 0: + raise AgentMemoryValidationError("'offset' must be >= 0") + if filters is not None: + _validate_filter_clauses(filters, {"metadata", "content"}) + params = build_list_params( + filter_expr=build_message_filter( + agent_id=agent_id, + invoker_id=invoker_id, + message_group=message_group, + role=role, + filter_clauses=filters, + ), + top=limit, + skip=offset if offset else None, + ) + response = self._transport.get(MESSAGES, params=params) + items, _ = extract_value_and_count(response) + return [Message.from_dict(item) for item in items] + + # ── Admin operations ─────────────────────────────────────────────────────── + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_GET_RETENTION_CONFIG) + def get_retention_config(self) -> RetentionConfig: + """Retrieve the data retention configuration (singleton). + + Returns: + The current :class:`RetentionConfig`. + + Raises: + AgentMemoryHttpError: If the request fails. + """ + data = self._transport.get(RETENTION_CONFIG) + return RetentionConfig.from_dict(data) + + @record_metrics(Module.AGENT_MEMORY, Operation.AGENT_MEMORY_UPDATE_RETENTION_CONFIG) + def update_retention_config( + self, + *, + message_days: Optional[int] = None, + memory_days: Optional[int] = None, + usage_log_days: Optional[int] = None, + ) -> None: + """Update the data retention configuration. + + Only the provided fields are updated. Set a field to ``0`` to + explicitly pass zero, or omit it to leave unchanged. + The server accepts ``null`` to disable cleanup for a category. + + Args: + message_days: How long to keep messages (days). + memory_days: How long to keep memories without access (days). + usage_log_days: How long to keep access and search logs (days). + + Raises: + AgentMemoryValidationError: If no fields are provided, or any provided + value is negative. + AgentMemoryHttpError: If the request fails. + """ + if message_days is None and memory_days is None and usage_log_days is None: + raise AgentMemoryValidationError( + "At least one of 'message_days', 'memory_days', or " + "'usage_log_days' must be provided" + ) + for name, value in ( + ("message_days", message_days), + ("memory_days", memory_days), + ("usage_log_days", usage_log_days), + ): + if value is not None and value < 0: + raise AgentMemoryValidationError(f"'{name}' must be >= 0") + payload: dict[str, Any] = {} + if message_days is not None: + payload["messageDays"] = message_days + if memory_days is not None: + payload["memoryDays"] = memory_days + if usage_log_days is not None: + payload["usageLogDays"] = usage_log_days + self._transport.patch(RETENTION_CONFIG, json=payload) diff --git a/src/sap_cloud_sdk/agent_memory/config.py b/src/sap_cloud_sdk/agent_memory/config.py new file mode 100644 index 0000000..fbe4f17 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/config.py @@ -0,0 +1,173 @@ +"""Configuration and secret resolution for the Agent Memory service. + +Loads service binding secrets from a mounted volume with environment fallback, +then normalises into an ``AgentMemoryConfig``. + +Mount path convention:: + + + + /etc/secrets/appfnd/hana-agent-memory/default/{field_key} + +Keys: ``application_url``, ``uaa.url``, ``uaa.clientid``, ``uaa.clientsecret`` + +Env fallback convention (uppercased):: + + CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_APPLICATION_URL + CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_URL + CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTID + CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTSECRET +""" + +from dataclasses import dataclass, field +from typing import Optional + +from sap_cloud_sdk.agent_memory.exceptions import AgentMemoryConfigError + + +@dataclass +class AgentMemoryConfig: + """Configuration for the Agent Memory service. + + Attributes: + base_url: The base URL of the Agent Memory service. + token_url: The OAuth2 token endpoint URL. Optional — if not provided, + requests are sent without authentication (useful for local development). + client_id: The OAuth2 client ID. Optional. + client_secret: The OAuth2 client secret. Optional. + timeout: Timeout in seconds for HTTP requests. Default is 30.0. + + Example — deployed BTP service:: + + config = AgentMemoryConfig( + base_url="https://", + token_url="https://.authentication./oauth/token", + client_id="", + client_secret="", + ) + + Example — local development (no auth):: + + config = AgentMemoryConfig(base_url="http://localhost:3000") + """ + + base_url: str + token_url: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None + timeout: float = 30.0 + + def __post_init__(self) -> None: + if not self.base_url: + raise AgentMemoryConfigError("base_url must be a non-empty string") + + +# NOTE: BindingData must NOT use `from __future__ import annotations` +# because the secret resolver checks `f.type is str` at runtime, which requires +# actual type objects rather than string annotations. + + +@dataclass +class BindingData: + """Raw binding secrets read by the secret resolver. + + All fields must be plain ``str`` to satisfy the resolver contract. + """ + + application_url: str = "" + uaa_url: str = field(default="", metadata={"secret": "uaa.url"}) + uaa_clientid: str = field(default="", metadata={"secret": "uaa.clientid"}) + uaa_clientsecret: str = field(default="", metadata={"secret": "uaa.clientsecret"}) + + def validate(self) -> None: + """Raise ``AgentMemoryConfigError`` if any required field is empty.""" + missing = [ + f + for f in ("application_url", "uaa_url", "uaa_clientid", "uaa_clientsecret") + if not getattr(self, f) + ] + if missing: + raise AgentMemoryConfigError( + f"Agent Memory binding is missing required fields: {', '.join(missing)}" + ) + + def extract_config(self) -> AgentMemoryConfig: + """Derive an ``AgentMemoryConfig`` from the raw binding fields.""" + return AgentMemoryConfig( + base_url=self.application_url, + token_url=self.uaa_url.rstrip("/") + "/oauth/token", + client_id=self.uaa_clientid, + client_secret=self.uaa_clientsecret, + ) + + +_ENV_PREFIX = "CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT" + +# Explicit env var names — dots are not valid in shell variable names, +# so we define these directly rather than deriving them from BindingData metadata keys +# (which use dots to match the BTP mount-path file naming convention). +_ENV_VARS = { + "application_url": f"{_ENV_PREFIX}_APPLICATION_URL", + "uaa_url": f"{_ENV_PREFIX}_UAA_URL", + "uaa_clientid": f"{_ENV_PREFIX}_UAA_CLIENTID", + "uaa_clientsecret": f"{_ENV_PREFIX}_UAA_CLIENTSECRET", +} + + +def _load_binding_from_env() -> BindingData: + """Read Agent Memory binding from environment variables. + + Raises: + AgentMemoryConfigError: If any required variable is absent. + """ + import os + + binding = BindingData() + missing: list[str] = [] + for attr, var in _ENV_VARS.items(): + value = os.environ.get(var) + if not value: + missing.append(var) + else: + setattr(binding, attr, value) + if missing: + raise AgentMemoryConfigError( + f"Missing required environment variables: {', '.join(missing)}" + ) + return binding + + +def _load_config_from_env() -> AgentMemoryConfig: + """Load Agent Memory configuration from a mounted volume or environment variables. + + Tries (in order): + 1. Mount at ``/etc/secrets/appfnd/hana-agent-memory/default/`` + 2. Environment variables ``CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_*`` + + Returns: + A validated ``AgentMemoryConfig``. + + Raises: + AgentMemoryConfigError: If configuration cannot be loaded or is incomplete. + """ + from sap_cloud_sdk.core.secret_resolver.resolver import _load_from_mount + + mount_error: Exception | None = None + try: + binding = BindingData() + _load_from_mount("/etc/secrets/appfnd", "hana-agent-memory", "default", binding) + binding.validate() + return binding.extract_config() + except Exception as exc: + mount_error = exc + + try: + binding = _load_binding_from_env() + binding.validate() + return binding.extract_config() + except AgentMemoryConfigError: + raise + except Exception as exc: + raise AgentMemoryConfigError( + f"Failed to load Agent Memory configuration: mount={mount_error}; env={exc}" + ) from exc diff --git a/src/sap_cloud_sdk/agent_memory/exceptions.py b/src/sap_cloud_sdk/agent_memory/exceptions.py new file mode 100644 index 0000000..e0ea2f4 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/exceptions.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Optional + + +class AgentMemoryError(Exception): + """Base exception for all Agent Memory service errors.""" + + +class AgentMemoryConfigError(AgentMemoryError): + """Raised for configuration errors (missing env vars, bad VCAP_SERVICES, empty base_url).""" + + +class AgentMemoryHttpError(AgentMemoryError): + """Raised for HTTP and network failures. + + Attributes: + status_code: HTTP status code, or None for network-level failures. + response_text: Raw response body text, or None if unavailable. + """ + + def __init__( + self, + message: str, + *, + status_code: Optional[int] = None, + response_text: Optional[str] = None, + ) -> None: + super().__init__(message) + self.status_code = status_code + self.response_text = response_text + + +class AgentMemoryNotFoundError(AgentMemoryHttpError): + """Raised when the Agent Memory service returns 404 Not Found.""" + + +class AgentMemoryValidationError(AgentMemoryError): + """Raised when client-side input validation fails before a request is sent.""" diff --git a/src/sap_cloud_sdk/agent_memory/py.typed b/src/sap_cloud_sdk/agent_memory/py.typed new file mode 100644 index 0000000..7f8a6b0 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 to indicate the 'agent_memory' package is typed. \ No newline at end of file diff --git a/src/sap_cloud_sdk/agent_memory/user-guide.md b/src/sap_cloud_sdk/agent_memory/user-guide.md new file mode 100644 index 0000000..338fa97 --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/user-guide.md @@ -0,0 +1,785 @@ +# Agent Memory User Guide + +This module provides a Python client for the SAP Agent Memory service (v1 API). It lets agents +store, retrieve, and semantically search persistent memories, and record conversation messages +grouped into logical message groups. The service handles vector embeddings automatically for memories — you store +plain text, and the service makes it searchable by meaning. + +> [!NOTE] +> Memory extraction is the caller's responsibility. This client stores whatever text you pass +> as `content`; it does not extract or summarize memories from conversation text on its own. + +## Table of Contents + +- [Agent Memory User Guide](#agent-memory-user-guide) + - [Table of Contents](#table-of-contents) + - [Installation](#installation) + - [Import](#import) + - [Quick Start](#quick-start) + - [Basic Setup](#basic-setup) + - [Custom Configuration](#custom-configuration) + - [Using the Context Manager](#using-the-context-manager) + - [Core Concepts](#core-concepts) + - [`agent_id`](#agent_id) + - [`invoker_id`](#invoker_id) + - [Semantic Search: A Brief Primer](#semantic-search-a-brief-primer) + - [Memories](#memories) + - [Create a Memory](#create-a-memory) + - [Get a Memory](#get-a-memory) + - [Update a Memory](#update-a-memory) + - [Delete a Memory](#delete-a-memory) + - [List Memories](#list-memories) + - [Content and metadata filtering](#content-and-metadata-filtering) + - [Count Memories](#count-memories) + - [Semantic Search](#semantic-search) + - [Messages](#messages) + - [Create a Message](#create-a-message) + - [Get a Message](#get-a-message) + - [Delete a Message](#delete-a-message) + - [List Messages](#list-messages) + - [Content and metadata filtering](#content-and-metadata-filtering-1) + - [Data Models](#data-models) + - [Enums](#enums) + - [Error Handling](#error-handling) + - [Admin — Retention Config](#admin--retention-config) + - [Get Retention Config](#get-retention-config) + - [Update Retention Config](#update-retention-config) + - [Common Scenarios](#common-scenarios) + - [Injecting relevant memories into an LLM prompt](#injecting-relevant-memories-into-an-llm-prompt) + - [Persisting a conversation turn](#persisting-a-conversation-turn) + - [Retrieving a full conversation thread](#retrieving-a-full-conversation-thread) + - [Paginating through all memories](#paginating-through-all-memories) + - [Paginating through all messages](#paginating-through-all-messages) + - [Troubleshooting](#troubleshooting) + - [`AgentMemoryConfigError` on startup](#agentmemoryconfigerror-on-startup) + - [`list_memories()` or `list_messages()` returns fewer results than expected](#list_memories-or-list_messages-returns-fewer-results-than-expected) + - [`search_memories()` returns no results](#search_memories-returns-no-results) + - [`AgentMemoryNotFoundError` when fetching a resource](#agentmemorynotfounderror-when-fetching-a-resource) + - [`AgentMemoryHttpError` with status 401](#agentmemoryhttperror-with-status-401) + - [Configuration](#configuration) + +## Installation + +See the [SAP Cloud SDK for Python installation guide](https://github.com/SAP/cloud-sdk-python#installation) +for setup instructions. The agent memory module is included automatically. + +## Import + +You can import specific classes: + +```python +from sap_cloud_sdk.agent_memory import ( + create_client, + AgentMemoryConfig, + FilterDefinition, + Memory, + Message, + MessageRole, + RetentionConfig, + SearchResult, +) +``` + +Or use a star import for convenience: + +```python +from sap_cloud_sdk.agent_memory import * +``` + +## Quick Start + +### Basic Setup + +Use `create_client()` to get a client with automatic credential detection: + +```python +from sap_cloud_sdk.agent_memory import create_client + +client = create_client() + +memories = client.list_memories(agent_id="my-agent", invoker_id="user-123") +print(f"Found {len(memories)} memories") +``` + +`create_client()` reads credentials from the `CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_*` +environment variables (or a mounted volume on BTP). See the +[Configuration](#configuration) section for the full variable table. + +### Custom Configuration + +There's also support for custom configuration if you want to specify credentials directly: + +```python +from sap_cloud_sdk.agent_memory import create_client, AgentMemoryConfig + +config = AgentMemoryConfig( + base_url="https://", + token_url="https://.authentication./oauth/token", + client_id="", + client_secret="", +) +client = create_client(config=config) +``` + +### Using the Context Manager + +The context manager is optional, but it is the easiest way to ensure the client is +closed even if an exception is raised: + +```python +with create_client() as client: + memories = client.list_memories(agent_id="my-agent", invoker_id="user-123") +``` + +To close the client manually, call `client.close()`. + +`close()` is only for local cleanup. It does **not** commit, flush, or roll back data. +Each API call is independent and final once accepted by the service. + +Calling methods after `close()` is supported. + +## Core Concepts + +### `agent_id` + +A stable identifier for the agent that owns the data — for example `"hr-assistant"` or +`"support-bot"`. Chosen by the implementer; typically the name or ID of the AI agent. + +### `invoker_id` + +Identifies the user or caller associated with the data — for example a user ID from +the application's auth system. Memories and messages are scoped to the combination of +`agent_id` and `invoker_id`. + +Neither value is validated by the service — they are free-form strings. Consistent use +across create, read, and search calls is the implementer's responsibility. + +## Semantic Search: A Brief Primer + +Texts with different words — or even different languages — can have the same meaning. +"How to make pizza dough?" and "Italian flatbread preparation steps" are semantically similar +despite sharing no words. To search a large corpus by meaning rather than exact keywords, the +service uses vector embeddings. + +An embedding model translates a text into a high-dimensional numeric vector. Texts with similar +meaning produce vectors that point in a similar direction. The cosine similarity between two +vectors measures that directional closeness: a value near 1.0 means the texts are semantically +similar. + +Example corpus: + +- "Trains cross bridges" +- "Clouds block sunlight" +- "Rivers carve valleys" +- "Wolves hunt deer" +- "Engines power ships" + +A search for "Sky illumination" returns "Clouds block sunlight" — closest in meaning, with the +highest cosine similarity — even though the query shares no words with the result. + +`search_memories()` uses this mechanism: you pass a natural-language query and a similarity +threshold, and the service returns the most semantically relevant stored memories. + +## Memories + +Memories are persistent knowledge entries scoped to an `agent_id` + `invoker_id` pair. +The service generates a vector embedding for each memory automatically, enabling semantic search. + +### Create a Memory + +```python +memory = client.add_memory( + agent_id="my-agent", + invoker_id="user-123", + content="The user prefers dark mode and metric units.", + metadata={"source": "preferences"}, +) +print(memory.id) +# "a1b2c3d4-e5f6-7890-abcd-ef1234567890" +``` + +**Required fields:** + +- `agent_id`: Identifier of the agent that owns this memory. +- `invoker_id`: Identifier of the user or caller associated with this memory. +- `content`: The memory text (plain string). + +**Optional fields:** + +- `metadata`: Arbitrary key-value dict stored alongside the memory. + +### Get a Memory + +```python +memory = client.get_memory(memory_id="") +print(memory.content) +# "The user prefers dark mode and metric units." +``` + +### Update a Memory + +`update_memory` performs a partial update; +omitted fields remain untouched. + +> [!NOTE] +> `content` and `metadata` are the only editable fields; `memory_id` identifies which memory to update and cannot be modified + +```python +client.update_memory( + memory_id="", + content="user prefers dark mode, metric units, and large font.", + metadata={"source": "preferences", "version": 2}, +) +``` + +### Delete a Memory + +```python +client.delete_memory(memory_id="") +``` + +### List Memories + +```python +memories = client.list_memories( + agent_id="my-agent", + invoker_id="user-123", + limit=20, +) +for m in memories: + print(f" [{m.id}] {m.content[:80]}") +# [a1b2c3d4-...] The user prefers dark mode and metric units. +# [b2c3d4e5-...] The user's timezone is Europe/Berlin. +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| ------------ | ---------------------------------- | ------- | ------------------------------------------------- | +| `agent_id` | `str` \| `None` | `None` | Filter by agent identifier. | +| `invoker_id` | `str` \| `None` | `None` | Filter by invoker/user identifier. | +| `filters` | `list[FilterDefinition]` \| `None` | `None` | Substring filters on `"content"` or `"metadata"`. | +| `limit` | `int` | `50` | Maximum number of memories to return. | +| `offset` | `int` | `0` | Number of memories to skip (pagination). | + +**Returns:** `list[Memory]` + +#### Content and metadata filtering + +Use `FilterDefinition` to narrow results by substring. Import it alongside `create_client`: + +```python +from sap_cloud_sdk.agent_memory import create_client, FilterDefinition + +# Memories whose content contains "dark mode" +memories = client.list_memories( + agent_id="my-agent", + invoker_id="user-123", + filters=[FilterDefinition(target="content", contains="dark mode")], +) + +# Combined: content AND metadata must both match +memories = client.list_memories( + agent_id="my-agent", + invoker_id="user-123", + filters=[ + FilterDefinition(target="content", contains="dark mode"), + FilterDefinition(target="metadata", contains="preferences"), + ], +) +``` + +`target` must be `"content"` or `"metadata"`. Multiple clauses are combined with AND. + +> [!WARNING] +> Defining two clauses with the **same target** produces an AND predicate that requires +> both substrings to be present in the same field simultaneously. This is rarely +> intentional — for example: +> +> ```python +> filters=[ +> FilterDefinition(target="content", contains="user prefers"), +> FilterDefinition(target="content", contains="user doesn't prefer"), +> ] +> ``` +> +> Only memories whose content contains _both_ substrings will be returned, which is +> typically an empty result set. OR combining across clauses is not yet supported. + +> [!NOTE] +> Metadata is stored as a JSON string. Filtering on `"metadata"` performs a free-text +> substring match on the raw JSON — for example `contains="preferences"` matches any +> metadata whose serialized form includes that word. Structured key-value filtering +> (e.g. `metadata.source == "preferences"`) is not supported. + +### Count Memories + +Count memories without fetching their content. Near-zero cost. + +```python +total = client.count_memories(agent_id="my-agent", invoker_id="user-123") +print(f"Total memories: {total}") +# Total memories: 42 +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| ------------ | --------------- | ------- | ---------------------------------- | +| `agent_id` | `str` \| `None` | `None` | Filter by agent identifier. | +| `invoker_id` | `str` \| `None` | `None` | Filter by invoker/user identifier. | + +**Returns:** `int` + +### Semantic Search + +Search for memories whose meaning is similar to a natural-language query. The service returns +results ordered by relevance (highest similarity first). + +```python +results = client.search_memories( + agent_id="my-agent", + invoker_id="user-123", + query="What are the user's display preferences?", + threshold=0.6, + limit=5, +) +for r in results: + print(f"[similarity={r.similarity:.2f}] {r.content}") +# [similarity=0.92] The user prefers dark mode and metric units. +# [similarity=0.81] User last asked about display settings on 2025-01-10. +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| ------------ | ------- | ------- | -------------------------------------------------- | +| `agent_id` | `str` | — | Agent identifier to scope the search. | +| `invoker_id` | `str` | — | Invoker/user identifier to scope the search. | +| `query` | `str` | — | Natural-language search query (5–5000 characters). | +| `threshold` | `float` | `0.6` | Minimum cosine similarity score (0.0–1.0). | +| `limit` | `int` | `10` | Maximum number of results (1–50). | + +**Returns:** `list[SearchResult]` — each result extends `Memory` with a `similarity` (cosine score) field. + +--- + +## Messages + +Messages represent individual turns in a conversation. Messages sharing the same `message_group` +form a logical message group. The service does not enforce a session concept — grouping is done +entirely via the `message_group` value you choose. + +### Create a Message + +```python +from sap_cloud_sdk.agent_memory import MessageRole + +message = client.add_message( + agent_id="my-agent", + invoker_id="user-123", + message_group="conv-001", + role=MessageRole.USER, + content="What is the weather like today?", +) +print(message.id) +# "c3d4e5f6-a1b2-..." +``` + +**Required fields:** + +- `agent_id`: Identifier of the agent. +- `invoker_id`: Identifier of the user or caller. +- `message_group`: Message group identifier (any string; use a consistent value per conversation). +- `role`: Author role — use the `MessageRole` enum: `USER`, `ASSISTANT`, `SYSTEM`, `TOOL`. +- `content`: The message text. + +**Optional fields:** + +- `metadata`: Arbitrary key-value dict stored alongside the message. + +### Get a Message + +```python +message = client.get_message(message_id="") +print(f"[{message.role}] {message.content}") +# [USER] What is the weather like today? +``` + +### Delete a Message + +```python +client.delete_message(message_id="") +``` + +### List Messages + +```python +messages = client.list_messages( + agent_id="my-agent", + invoker_id="user-123", + message_group="conv-001", + limit=50, +) +for msg in messages: + print(f" [{msg.role}] {msg.content[:80]}") +# [USER] What is the weather like today? +# [ASSISTANT] It's sunny and 22°C in Berlin. +``` + +Filter by role to retrieve only a specific author's turns: + +```python +user_messages = client.list_messages( + agent_id="my-agent", + invoker_id="user-123", + message_group="conv-001", + role=MessageRole.USER, +) +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +| --------------- | ---------------------------------- | ------- | ------------------------------------------------- | +| `agent_id` | `str` \| `None` | `None` | Filter by agent identifier. | +| `invoker_id` | `str` \| `None` | `None` | Filter by invoker/user identifier. | +| `message_group` | `str` \| `None` | `None` | Filter by conversation group. | +| `role` | `str` \| `None` | `None` | Filter by author role (USER, ASSISTANT, …). | +| `filters` | `list[FilterDefinition]` \| `None` | `None` | Substring filters on `"content"` or `"metadata"`. | +| `limit` | `int` | `50` | Maximum number of messages to return. | +| `offset` | `int` | `0` | Number of messages to skip (pagination). | + +**Returns:** `list[Message]` + +#### Content and metadata filtering + +The same `FilterDefinition` syntax applies to messages: + +```python +from sap_cloud_sdk.agent_memory import create_client, FilterDefinition + +# Messages whose metadata contains a specific tag +messages = client.list_messages( + agent_id="my-agent", + invoker_id="user-123", + message_group="conversation-001", + filters=[FilterDefinition(target="metadata", contains="escalated")], +) + +# Messages whose content mentions a keyword +messages = client.list_messages( + agent_id="my-agent", + invoker_id="user-123", + filters=[FilterDefinition(target="content", contains="invoice")], +) +``` + +See the [Content and metadata filtering](#content-and-metadata-filtering) note under +[List Memories](#list-memories) for details on metadata free-text limitations. + +--- + +## Data Models + +| Model | Description | +| ----------------- | ----------------------------------------------------------------------------------------------------- | +| `Memory` | A persistent memory entry (`id`, `agent_id`, `invoker_id`, `content`, `metadata`, timestamps) | +| `SearchResult` | Extends `Memory` with a `similarity` field (cosine score, 0–1) | +| `Message` | A message (`id`, `agent_id`, `invoker_id`, `message_group`, `role`, `content`, `metadata`, timestamp) | +| `RetentionConfig` | Data retention policy (`message_days`, `memory_days`, `usage_log_days`, timestamps) | + +### Enums + +| Enum | Values | +| ------------- | ------------------------------------- | +| `MessageRole` | `USER`, `ASSISTANT`, `SYSTEM`, `TOOL` | + +All models expose a `to_dict()` method that returns a plain dict for logging or forwarding. + +```python +memory = client.get_memory(memory_id="a1b2c3d4-...") +print(memory.to_dict()) +# { +# "id": "a1b2c3d4-...", +# "agent_id": "my-agent", +# "invoker_id": "user-123", +# "content": "The user prefers dark mode and metric units.", +# "metadata": {}, +# "created_at": "2025-01-10T12:00:00Z", +# "updated_at": "2025-01-10T12:00:00Z", +# } +``` + +--- + +## Error Handling + +The module defines a structured exception hierarchy so you can catch errors at the appropriate +level of specificity: + +``` +AgentMemoryError +├── AgentMemoryConfigError # bad or missing configuration +├── AgentMemoryValidationError # invalid inputs caught before any network call +└── AgentMemoryHttpError # HTTP-level error (status_code, response_text) + └── AgentMemoryNotFoundError # 404 Not Found +``` + +```python +from sap_cloud_sdk.agent_memory.exceptions import ( + AgentMemoryError, + AgentMemoryConfigError, + AgentMemoryValidationError, + AgentMemoryHttpError, + AgentMemoryNotFoundError, +) + +# Catch invalid inputs before they reach the network +try: + client.add_memory(agent_id="", invoker_id="user-123", content="hello") +except AgentMemoryValidationError as e: + print(f"Bad input: {e}") +# Bad input: Required field(s) must be non-empty: 'agent_id' + +# Catch a specific 404 +try: + memory = client.get_memory(memory_id="non-existent-id") +except AgentMemoryNotFoundError: + print("Memory not found") + +# Inspect the HTTP status code and response body +try: + memories = client.list_memories(agent_id="my-agent") +except AgentMemoryHttpError as e: + print(f"HTTP {e.status_code}: {e.response_text}") + +# Catch all Agent Memory errors +try: + client = create_client() + memories = client.list_memories(agent_id="my-agent") +except AgentMemoryError as e: + print(f"Agent Memory error: {e}") +``` + +--- + +## Admin — Retention Config + +The retention configuration controls automatic data cleanup. It is a singleton — one config +per tenant. + +### Get Retention Config + +```python +rc = client.get_retention_config() +print(f"Messages: {rc.message_days} days") +print(f"Memories: {rc.memory_days} days") +print(f"Usage logs: {rc.usage_log_days} days") +``` + +### Update Retention Config + +`update_retention_config` performs a partial update — only the provided fields are +changed; omitted fields remain unchanged. + +```python +client.update_retention_config( + message_days=30, + memory_days=90, + usage_log_days=180, +) +``` + +Set a field to `0` to mark all data in that category for deletion at the next nightly scheduled cleanup. The server also accepts `null` to disable +automatic cleanup for that category. + +**When changes take effect** + +The service runs nightly data cleanup procedures that delete records based on creation timestamp. Changes to retention configuration apply to all future retention sweeps. The new retention window is calculated from each record's original creation timestamp, not from the time of the config change. + +_Increasing retention_ — records that were approaching expiry get more time. For example, +if `message_days` is raised from 90 to 120, a message created 89 days ago will now be +retained until it reaches 120 days old rather than being cleaned up after 90 days. + +_Decreasing retention_ — records outside the new window become eligible for removal. For +example, if `message_days` is reduced from 90 to 30, messages older than 30 days will be +removed at the next retention sweep, even if they fell within the original 90-day limit +when they were created. + +> [!WARNING] +> Decreasing a retention period is a destructive, irreversible operation. Records outside +> the new window are permanently deleted at the next cleanup sweep. + +--- + +## Common Scenarios + +### Injecting relevant memories into an LLM prompt + +Retrieve the most semantically relevant past memories before calling the language model: + +```python +def build_context(client, agent_id, invoker_id, user_query): + results = client.search_memories( + agent_id=agent_id, + invoker_id=invoker_id, + query=user_query, + threshold=0.65, + limit=5, + ) + if not results: + return "" + lines = [f"- {r.content}" for r in results] + return "Relevant context from memory:\n" + "\n".join(lines) +``` + +### Persisting a conversation turn + +Store each user and assistant message so the full conversation history is available: + +```python +def record_turn(client, agent_id, invoker_id, group_id, user_text, assistant_text): + client.add_message( + agent_id=agent_id, + invoker_id=invoker_id, + message_group=group_id, + role=MessageRole.USER, + content=user_text, + ) + client.add_message( + agent_id=agent_id, + invoker_id=invoker_id, + message_group=group_id, + role=MessageRole.ASSISTANT, + content=assistant_text, + ) +``` + +### Retrieving a full conversation thread + +```python +def get_conversation(client, agent_id, invoker_id, group_id): + return client.list_messages( + agent_id=agent_id, + invoker_id=invoker_id, + message_group=group_id, + limit=100, + ) +``` + +### Paginating through all memories + +`list_memories` returns at most `limit` results per call. Use `offset` to page through large +sets, or use `count_memories` first to decide whether pagination is even necessary: + +```python +PAGE_SIZE = 100 + +total = client.count_memories(agent_id="my-agent", invoker_id="user-123") +if total == 0: + memories = [] +elif total <= PAGE_SIZE: + memories = client.list_memories( + agent_id="my-agent", invoker_id="user-123", limit=total + ) +else: + def iter_all_memories(client, agent_id, invoker_id, page_size=PAGE_SIZE): + offset = 0 + while True: + page = client.list_memories( + agent_id=agent_id, + invoker_id=invoker_id, + limit=page_size, + offset=offset, + ) + yield from page + if len(page) < page_size: + break + offset += page_size + + memories = list(iter_all_memories(client, "my-agent", "user-123")) +``` + +### Paginating through all messages + +```python +def iter_all_messages(client, agent_id, invoker_id, message_group, page_size=100): + offset = 0 + while True: + page = client.list_messages( + agent_id=agent_id, + invoker_id=invoker_id, + message_group=message_group, + limit=page_size, + offset=offset, + ) + yield from page + if len(page) < page_size: + break + offset += page_size +``` + +--- + +## Troubleshooting + +### `AgentMemoryConfigError` on startup + +``` +AgentMemoryConfigError: Failed to load configuration: ... +``` + +Credentials could not be found. Check that either: + +- The BTP service binding is mounted at `/etc/secrets/appfnd/hana-agent-memory/default/` +- Or the environment variables are set (see [Configuration](#configuration)) + +### `list_memories()` or `list_messages()` returns fewer results than expected + +The default `limit` is `50`. Increase it or paginate: + +```python +memories = client.list_memories(agent_id="my-agent", invoker_id="user-123", limit=200) +``` + +Also verify `agent_id` and `invoker_id` exactly match the values used when the memories were created. + +### `search_memories()` returns no results + +The default `threshold` of `0.6` may be too strict for your data. Try a lower value: + +```python +results = client.search_memories( + agent_id="my-agent", invoker_id="user-123", + query="user display preferences", + threshold=0.3, +) +``` + +### `AgentMemoryNotFoundError` when fetching a resource + +The resource was deleted, the ID is incorrect, or the `agent_id`/`invoker_id` passed to a +list or search operation does not match the values used when the resource was created. + +### `AgentMemoryHttpError` with status 401 + +The OAuth2 token has expired and automatic refresh failed, or the configured credentials +(`client_id`, `client_secret`, `token_url`) are incorrect. Verify the credentials in your +environment variables or service binding. + +--- + +## Configuration + +`create_client()` resolves credentials automatically in the following order: + +1. **Mounted volume** — `/etc/secrets/appfnd/hana-agent-memory/default/{field}` +2. **Environment variables** — `CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_*` + +| Environment Variable | Description | +| ---------------------------------------------------------- | ------------------------------------ | +| `CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_APPLICATION_URL` | Base URL of the Agent Memory service | +| `CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_URL` | OAuth2 authorization server base URL | +| `CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTID` | OAuth2 client ID | +| `CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTSECRET` | OAuth2 client secret | diff --git a/src/sap_cloud_sdk/agent_memory/utils/__init__.py b/src/sap_cloud_sdk/agent_memory/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/sap_cloud_sdk/agent_memory/utils/_odata.py b/src/sap_cloud_sdk/agent_memory/utils/_odata.py new file mode 100644 index 0000000..1cb1f9c --- /dev/null +++ b/src/sap_cloud_sdk/agent_memory/utils/_odata.py @@ -0,0 +1,162 @@ +"""OData v4 query-building utilities for the Agent Memory service. + +When migrating to a non-OData API, replace the helpers in this file and update +the call sites in ``client.py``. The client method signatures stay the same. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class FilterDefinition: + """A single ``contains`` predicate used in the ``filter`` parameter. + + Args: + target: Field to filter on. Allowed values: ``"metadata"``, ``"content"``. + contains: Required substring. Must be non-empty. + """ + + target: str + contains: str + + +def _escape_odata_string(value: str) -> str: + """Escape single quotes in an OData string literal ('' per OData 4.01).""" + return value.replace("'", "''") + + +def build_contains_clauses(clauses: list[FilterDefinition]) -> list[str]: + """Convert FilterDefinition objects into OData contains() expressions.""" + return [ + f"contains({clause.target}, '{_escape_odata_string(clause.contains)}')" + for clause in clauses + ] + + +def build_memory_filter( + agent_id: Optional[str] = None, + invoker_id: Optional[str] = None, + filter_clauses: Optional[list[FilterDefinition]] = None, +) -> Optional[str]: + """Build an OData ``$filter`` expression for memories. + + Args: + agent_id: Filter by agent identifier. + invoker_id: Filter by invoker/user identifier. + filter_clauses: Additional ``contains`` predicates. + + Returns: + A ``$filter`` string, or ``None`` if no filters are requested. + """ + parts: list[str] = [] + if agent_id: + parts.append(f"agentID eq '{_escape_odata_string(agent_id)}'") + if invoker_id: + parts.append(f"invokerID eq '{_escape_odata_string(invoker_id)}'") + if filter_clauses: + parts.extend(build_contains_clauses(filter_clauses)) + return " and ".join(parts) if parts else None + + +def build_message_filter( + agent_id: Optional[str] = None, + invoker_id: Optional[str] = None, + message_group: Optional[str] = None, + role: Optional[str] = None, + filter_clauses: Optional[list[FilterDefinition]] = None, +) -> Optional[str]: + """Build an OData ``$filter`` expression for messages. + + Args: + agent_id: Filter by agent identifier. + invoker_id: Filter by invoker/user identifier. + message_group: Filter by message group. + role: Filter by message role (USER, ASSISTANT, SYSTEM, TOOL). + filter_clauses: Additional ``contains`` predicates. + + Returns: + A ``$filter`` string, or ``None`` if no filters are requested. + """ + parts: list[str] = [] + if agent_id: + parts.append(f"agentID eq '{_escape_odata_string(agent_id)}'") + if invoker_id: + parts.append(f"invokerID eq '{_escape_odata_string(invoker_id)}'") + if message_group: + parts.append(f"messageGroup eq '{_escape_odata_string(message_group)}'") + if role: + parts.append(f"role eq '{_escape_odata_string(role)}'") + if filter_clauses: + parts.extend(build_contains_clauses(filter_clauses)) + return " and ".join(parts) if parts else None + + +def build_list_params( + *, + filter_expr: Optional[str] = None, + search: Optional[str] = None, + select: Optional[str] = None, + top: Optional[int] = None, + skip: Optional[int] = None, + order_by: Optional[str] = None, + count: bool = False, +) -> dict[str, str]: + """Build a dictionary of OData query parameters for list operations. + + Args: + filter_expr: OData ``$filter`` expression string. + search: Free-text search expression (``$search``). + select: Comma-separated list of properties to return (``$select``). + top: Maximum number of results (``$top``). + skip: Number of results to skip for pagination (``$skip``). + order_by: Sort field and direction, e.g. ``"createTimestamp desc"`` (``$orderby``). + count: Whether to request the total count (``$count=true``). + + Returns: + A dict of query parameter name → value strings. + """ + params: dict[str, str] = {} + if filter_expr: + params["$filter"] = filter_expr + if search: + params["$search"] = search + if select: + # The server's audit log handler requires agentID and invokerID to be + # present on every read. Ensure they are always included in $select to + # avoid a 500 error from a NULL constraint on the access log table. + required = {"agentID", "invokerID"} + fields = {f.strip() for f in select.split(",")} + fields |= required + params["$select"] = ",".join(sorted(fields)) + if top is not None: + params["$top"] = str(top) + if skip is not None: + params["$skip"] = str(skip) + if order_by: + params["$orderby"] = order_by + if count: + params["$count"] = "true" + return params + + +def extract_value_and_count(response: dict) -> tuple[list[dict], Optional[int]]: + """Extract the items array and optional count from a list response. + + Supports both OData v4 (``value`` / ``@odata.count``) and the Agent Memory + service format (``data`` / ``count``). + + Args: + response: The parsed JSON response from a list endpoint. + + Returns: + A tuple of (list of item dicts, total count or None). + """ + # OData v4 standard uses "value" key; fall back to "data" for compatibility + items: list[dict] = response.get("value", response.get("data", [])) + total: Optional[int] = response.get( + "@odata.count", response.get("@count", response.get("count")) + ) + return items, total diff --git a/src/sap_cloud_sdk/core/telemetry/module.py b/src/sap_cloud_sdk/core/telemetry/module.py index 686e9d3..502d216 100644 --- a/src/sap_cloud_sdk/core/telemetry/module.py +++ b/src/sap_cloud_sdk/core/telemetry/module.py @@ -9,6 +9,7 @@ class Module(str, Enum): AICORE = "aicore" AUDITLOG = "auditlog" AUDITLOG_NG = "auditlog_ng" + AGENT_MEMORY = "agent_memory" DESTINATION = "destination" OBJECTSTORE = "objectstore" DMS = "dms" diff --git a/src/sap_cloud_sdk/core/telemetry/operation.py b/src/sap_cloud_sdk/core/telemetry/operation.py index 9cb83cf..bd13631 100644 --- a/src/sap_cloud_sdk/core/telemetry/operation.py +++ b/src/sap_cloud_sdk/core/telemetry/operation.py @@ -98,5 +98,20 @@ class Operation(str, Enum): DMS_APPEND_CONTENT_STREAM = "cmis_append_content_stream" DMS_CMIS_QUERY = "cmis_query" + # Agent Memory Operations + AGENT_MEMORY_ADD_MEMORY = "add_memory" + AGENT_MEMORY_GET_MEMORY = "get_memory" + AGENT_MEMORY_UPDATE_MEMORY = "update_memory" + AGENT_MEMORY_DELETE_MEMORY = "delete_memory" + AGENT_MEMORY_LIST_MEMORIES = "list_memories" + AGENT_MEMORY_COUNT_MEMORIES = "count_memories" + AGENT_MEMORY_SEARCH_MEMORIES = "search_memories" + AGENT_MEMORY_ADD_MESSAGE = "add_message" + AGENT_MEMORY_GET_MESSAGE = "get_message" + AGENT_MEMORY_DELETE_MESSAGE = "delete_message" + AGENT_MEMORY_LIST_MESSAGES = "list_messages" + AGENT_MEMORY_GET_RETENTION_CONFIG = "get_retention_config" + AGENT_MEMORY_UPDATE_RETENTION_CONFIG = "update_retention_config" + def __str__(self) -> str: return self.value diff --git a/tests/agent_memory/__init__.py b/tests/agent_memory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agent_memory/integration/__init__.py b/tests/agent_memory/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agent_memory/integration/agentmemory.feature b/tests/agent_memory/integration/agentmemory.feature new file mode 100644 index 0000000..dbe3c56 --- /dev/null +++ b/tests/agent_memory/integration/agentmemory.feature @@ -0,0 +1,91 @@ +Feature: Agent Memory Service Integration (v1 API) + + Background: + Given a configured Agent Memory client + + # ── Memory CRUD ───────────────────────────────────────────────────────────── + + Scenario: Create a new memory + When I create a memory with agent "test-agent" and invoker "test-user" and content "User prefers dark mode" + Then the memory should have a non-empty id + And the memory should have agent_id "test-agent" + And the memory should have invoker_id "test-user" + And the memory should have content "User prefers dark mode" + + Scenario: Get an existing memory + Given a memory exists with agent "test-agent" and invoker "test-user" and content "Test memory" + When I get the memory by id + Then the returned memory should match the created memory + + Scenario: Update memory content + Given a memory exists with agent "test-agent" and invoker "test-user" and content "Original content" + When I update the memory content to "Updated content" + Then the memory should have content "Updated content" + + Scenario: List memories with filter + Given a memory exists with agent "test-agent" and invoker "test-user" and content "Listed memory" + When I list memories filtered by agent "test-agent" + Then the result should contain at least one memory + And the total count should be a positive number + + Scenario: Delete a memory + Given a memory exists with agent "test-agent" and invoker "test-user" and content "To be deleted" + When I delete the memory + Then the memory should no longer exist + + # ── Memory search ─────────────────────────────────────────────────────────── + + Scenario: Search memories by semantic query + Given a memory exists with agent "test-agent" and invoker "test-user" and content "The user loves dark mode and dark themes" + When I search for memories with query "dark mode preference" + Then the search result should contain at least one result + And each result should have a non-empty content + + # ── Message CRUD ──────────────────────────────────────────────────────────── + + Scenario: Create and get a message + When I create a message with agent "test-agent" invoker "test-user" group "conv-1" role "USER" content "Hello!" + Then the message should have a non-empty id + And the message should have role "USER" + And the message should have content "Hello!" + + Scenario: List messages with filter + Given a message exists with agent "test-agent" invoker "test-user" group "conv-list" role "USER" content "Listed message" + When I list messages filtered by agent "test-agent" and group "conv-list" + Then the result should contain at least one message + And the total count should be a positive number + + Scenario: Delete a message + Given a message exists with agent "test-agent" invoker "test-user" group "conv-del" role "USER" content "To be deleted" + When I delete the message + Then the message should no longer exist + + # ── Admin — Retention Config ──────────────────────────────────────────────── + + Scenario: Get retention config + When I get the retention config + Then the retention config should have a non-empty id + + Scenario: Update retention config + When I update the retention config with message_days 30 and memory_days 90 + Then the retention config should have message_days 30 + And the retention config should have memory_days 90 + + # ── Bulk / utility operations ──────────────────────────────────────────────── + + Scenario: Count memories for an agent and invoker + Given a memory exists with agent "test-agent" and invoker "test-user" and content "Count test memory" + When I count memories for agent "test-agent" and invoker "test-user" + Then the memory count should be a positive number + + # ── Filter ─────────────────────────────────────────────────────────────────── + + Scenario: Filter memories by content substring + Given a memory exists with agent "test-agent" and invoker "test-user" and content "The user prefers dark mode" + When I list memories filtered by content containing "dark mode" + Then the result should contain the memory with content "The user prefers dark mode" + + Scenario: Filter messages by metadata substring + Given a message exists with agent "test-agent" invoker "test-user" group "conv-filter" role "USER" content "filter-test-message" and metadata "filter-marker" + When I list messages filtered by metadata containing "filter-marker" + Then the result should contain the message with content "filter-test-message" \ No newline at end of file diff --git a/tests/agent_memory/integration/conftest.py b/tests/agent_memory/integration/conftest.py new file mode 100644 index 0000000..9b26d5c --- /dev/null +++ b/tests/agent_memory/integration/conftest.py @@ -0,0 +1,30 @@ +"""Integration test fixtures for the Agent Memory service. + +Set the following environment variables before running integration tests: + + CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_URL Base URL of the Agent Memory service + CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_AUTH_URL OAuth2 authorization server base URL + CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_CLIENTID OAuth2 client ID + CLOUD_SDK_CFG_AGENT_MEMORY_DEFAULT_CLIENTSECRET OAuth2 client secret +""" + +from pathlib import Path + +import pytest +from dotenv import load_dotenv + +from sap_cloud_sdk.agent_memory import create_client +from sap_cloud_sdk.agent_memory.client import AgentMemoryClient + + +@pytest.fixture(scope="session") +def agent_memory_client() -> AgentMemoryClient: + """Create a real AgentMemoryClient from environment variables.""" + env_file = Path(__file__).parents[3] / ".env_integration_tests" + if env_file.exists(): + load_dotenv(env_file, override=True) + + try: + return create_client() + except Exception as e: + pytest.fail(f"Failed to create Agent Memory client for integration tests: {e}") diff --git a/tests/agent_memory/integration/test_agentmemory_bdd.py b/tests/agent_memory/integration/test_agentmemory_bdd.py new file mode 100644 index 0000000..fb7c450 --- /dev/null +++ b/tests/agent_memory/integration/test_agentmemory_bdd.py @@ -0,0 +1,501 @@ +"""BDD integration tests for the Agent Memory service (v1 API). + +Run against a live service: + + AGENT_MEMORY_BASE_URL=http://localhost:3000 pytest tests/agent_memory/integration + +Or against the deployed BTP service (with OAuth2): + + AGENT_MEMORY_BASE_URL=https://... \\ + AGENT_MEMORY_TOKEN_URL=https://... \\ + AGENT_MEMORY_CLIENT_ID=... \\ + AGENT_MEMORY_CLIENT_SECRET=... \\ + pytest tests/agent_memory/integration +""" + +import pytest +from pytest_bdd import given, scenario, then, when + +from sap_cloud_sdk.agent_memory.client import AgentMemoryClient + +# -- Scenarios ----------------------------------------------------------------- + + +@scenario("agentmemory.feature", "Create a new memory") +def test_add_memory(): + pass + + +@scenario("agentmemory.feature", "Get an existing memory") +def test_get_memory(): + pass + + +@scenario("agentmemory.feature", "Update memory content") +def test_update_memory(): + pass + + +@scenario("agentmemory.feature", "List memories with filter") +def test_list_memories(): + pass + + +@scenario("agentmemory.feature", "Delete a memory") +def test_delete_memory(): + pass + + +@scenario("agentmemory.feature", "Search memories by semantic query") +def test_search_memories(): + pass + + +@scenario("agentmemory.feature", "Create and get a message") +def test_add_message(): + pass + + +@scenario("agentmemory.feature", "List messages with filter") +def test_list_messages(): + pass + + +@scenario("agentmemory.feature", "Delete a message") +def test_delete_message(): + pass + + +@scenario("agentmemory.feature", "Get retention config") +def test_get_retention_config(): + pass + + +@scenario("agentmemory.feature", "Update retention config") +def test_update_retention_config(): + pass + + +@scenario("agentmemory.feature", "Count memories for an agent and invoker") +def test_count_memories(): + pass + + +@scenario("agentmemory.feature", "Filter memories by content substring") +def test_filter_memories_by_content(): + pass + + +@scenario("agentmemory.feature", "Filter messages by metadata substring") +def test_filter_messages_by_metadata(): + pass + + +# -- Fixtures / state --------------------------------------------------------- + + +@pytest.fixture +def context(): + return {} + + +# -- Given steps --------------------------------------------------------------- + + +@given("a configured Agent Memory client") +def configured_client(context, agent_memory_client): + context["client"] = agent_memory_client + + +@given( + 'a memory exists with agent "test-agent" and invoker "test-user" and content "Test memory"' +) +def memory_exists_test(context, agent_memory_client): + context["client"] = agent_memory_client + context["memory"] = agent_memory_client.add_memory( + "test-agent", + "test-user", + "Test memory", + ) + + +@given( + 'a memory exists with agent "test-agent" and invoker "test-user" and content "Original content"' +) +def memory_exists_original(context, agent_memory_client): + context["client"] = agent_memory_client + context["memory"] = agent_memory_client.add_memory( + "test-agent", + "test-user", + "Original content", + ) + + +@given( + 'a memory exists with agent "test-agent" and invoker "test-user" and content "Listed memory"' +) +def memory_exists_listed(context, agent_memory_client): + context["client"] = agent_memory_client + context["memory"] = agent_memory_client.add_memory( + "test-agent", + "test-user", + "Listed memory", + ) + + +@given( + 'a memory exists with agent "test-agent" and invoker "test-user" and content "To be deleted"' +) +def memory_exists_delete(context, agent_memory_client): + context["client"] = agent_memory_client + context["memory"] = agent_memory_client.add_memory( + "test-agent", + "test-user", + "To be deleted", + ) + + +@given( + 'a memory exists with agent "test-agent" and invoker "test-user" and content "The user loves dark mode and dark themes"' +) +def memory_exists_search(context, agent_memory_client): + context["client"] = agent_memory_client + context["memory"] = agent_memory_client.add_memory( + "test-agent", + "test-user", + "The user loves dark mode and dark themes", + ) + + +@given( + 'a message exists with agent "test-agent" invoker "test-user" group "conv-list" role "USER" content "Listed message"' +) +def message_exists_list(context, agent_memory_client): + context["client"] = agent_memory_client + context["message"] = agent_memory_client.add_message( + "test-agent", + "test-user", + "conv-list", + "USER", + "Listed message", + ) + + +@given( + 'a message exists with agent "test-agent" invoker "test-user" group "conv-del" role "USER" content "To be deleted"' +) +def message_exists_delete(context, agent_memory_client): + context["client"] = agent_memory_client + context["message"] = agent_memory_client.add_message( + "test-agent", + "test-user", + "conv-del", + "USER", + "To be deleted", + ) + + +# -- When steps ---------------------------------------------------------------- + + +@when( + 'I create a memory with agent "test-agent" and invoker "test-user" and content "User prefers dark mode"' +) +def add_memory(context): + client: AgentMemoryClient = context["client"] + context["memory"] = client.add_memory( + "test-agent", + "test-user", + "User prefers dark mode", + ) + + +@when("I get the memory by id") +def get_memory(context): + client: AgentMemoryClient = context["client"] + context["fetched_memory"] = client.get_memory(context["memory"].id) + + +@when('I update the memory content to "Updated content"') +def update_memory(context): + client: AgentMemoryClient = context["client"] + client.update_memory(context["memory"].id, content="Updated content") + context["memory"] = client.get_memory(context["memory"].id) + + +@when('I list memories filtered by agent "test-agent"') +def list_memories(context): + client: AgentMemoryClient = context["client"] + context["memories"] = client.list_memories(agent_id="test-agent") + context["total"] = client.count_memories(agent_id="test-agent") + + +@when("I delete the memory") +def delete_memory(context): + client: AgentMemoryClient = context["client"] + client.delete_memory(context["memory"].id) + context["deleted_memory_id"] = context["memory"].id + + +@when('I search for memories with query "dark mode preference"') +def search_memories(context): + client: AgentMemoryClient = context["client"] + context["search_results"] = client.search_memories( + agent_id="test-agent", + invoker_id="test-user", + query="dark mode preference", + threshold=0.5, + limit=10, + ) + + +@when( + 'I create a message with agent "test-agent" invoker "test-user" group "conv-1" role "USER" content "Hello!"' +) +def add_message(context): + client: AgentMemoryClient = context["client"] + context["message"] = client.add_message( + "test-agent", + "test-user", + "conv-1", + "USER", + "Hello!", + ) + + +@when('I list messages filtered by agent "test-agent" and group "conv-list"') +def list_messages(context): + client: AgentMemoryClient = context["client"] + context["messages"] = client.list_messages( + agent_id="test-agent", + message_group="conv-list", + ) + context["total"] = len(context["messages"]) + + +@when("I delete the message") +def delete_message(context): + client: AgentMemoryClient = context["client"] + client.delete_message(context["message"].id) + context["deleted_message_id"] = context["message"].id + + +# -- Then steps ---------------------------------------------------------------- + + +@then("the memory should have a non-empty id") +def check_memory_id(context): + assert context["memory"].id != "" + + +@then('the memory should have agent_id "test-agent"') +def check_memory_agent_id(context): + assert context["memory"].agent_id == "test-agent" + + +@then('the memory should have invoker_id "test-user"') +def check_memory_invoker_id(context): + assert context["memory"].invoker_id == "test-user" + + +@then('the memory should have content "User prefers dark mode"') +def check_memory_content_dark(context): + assert context["memory"].content == "User prefers dark mode" + + +@then('the memory should have content "Updated content"') +def check_memory_content_updated(context): + assert context["memory"].content == "Updated content" + + +@then("the returned memory should match the created memory") +def check_fetched_memory(context): + assert context["fetched_memory"].id == context["memory"].id + assert context["fetched_memory"].content == context["memory"].content + + +@then("the result should contain at least one memory") +def check_memories_not_empty(context): + assert len(context["memories"]) >= 1 + + +@then("the total count should be a positive number") +def check_total_positive(context): + assert context["total"] is not None + assert context["total"] >= 1 + + +@then("the memory should no longer exist") +def check_memory_deleted(context): + from sap_cloud_sdk.agent_memory.exceptions import AgentMemoryNotFoundError + + client: AgentMemoryClient = context["client"] + with pytest.raises(AgentMemoryNotFoundError): + client.get_memory(context["deleted_memory_id"]) + + +@then("the search result should contain at least one result") +def check_search_not_empty(context): + assert len(context["search_results"]) >= 1 + + +@then("each result should have a non-empty content") +def check_result_content(context): + for result in context["search_results"]: + assert result.content != "" + + +@then("the message should have a non-empty id") +def check_message_id(context): + assert context["message"].id != "" + + +@then('the message should have role "USER"') +def check_message_role(context): + assert context["message"].role == "USER" + + +@then('the message should have content "Hello!"') +def check_message_content(context): + assert context["message"].content == "Hello!" + + +@then("the result should contain at least one message") +def check_messages_not_empty(context): + assert len(context["messages"]) >= 1 + + +@then("the message should no longer exist") +def check_message_deleted(context): + from sap_cloud_sdk.agent_memory.exceptions import AgentMemoryNotFoundError + + client: AgentMemoryClient = context["client"] + with pytest.raises(AgentMemoryNotFoundError): + client.get_message(context["deleted_message_id"]) + + +# -- Admin: Retention Config steps --------------------------------------------- + + +@when("I get the retention config") +def get_retention_config(context): + client: AgentMemoryClient = context["client"] + context["retention_config"] = client.get_retention_config() + + +@when("I update the retention config with message_days 30 and memory_days 90") +def update_retention_config(context): + client: AgentMemoryClient = context["client"] + client.update_retention_config(message_days=30, memory_days=90) + context["retention_config"] = client.get_retention_config() + + +@then("the retention config should have a non-empty id") +def check_retention_config_id(context): + assert context["retention_config"].id != "" + + +@then("the retention config should have message_days 30") +def check_retention_message_days(context): + assert context["retention_config"].message_days == 30 + + +@then("the retention config should have memory_days 90") +def check_retention_memory_days(context): + assert context["retention_config"].memory_days == 90 + + +# -- Bulk / utility steps ------------------------------------------------------- + + +@given( + 'a memory exists with agent "test-agent" and invoker "test-user" and content "Count test memory"' +) +def memory_exists_count(context, agent_memory_client): + context["client"] = agent_memory_client + context["memory"] = agent_memory_client.add_memory( + "test-agent", + "test-user", + "Count test memory", + ) + + +@when('I count memories for agent "test-agent" and invoker "test-user"') +def count_memories(context): + client: AgentMemoryClient = context["client"] + context["memory_count"] = client.count_memories( + agent_id="test-agent", + invoker_id="test-user", + ) + + +@then("the memory count should be a positive number") +def check_memory_count_positive(context): + assert context["memory_count"] >= 1 + + +# -- Filter steps --------------------------------------------------------------- + + +@given( + 'a memory exists with agent "test-agent" and invoker "test-user" and content "The user prefers dark mode"' +) +def memory_exists_dark_mode(context, agent_memory_client): + context["client"] = agent_memory_client + context["memory"] = agent_memory_client.add_memory( + "test-agent", + "test-user", + "The user prefers dark mode", + ) + + +@given( + 'a message exists with agent "test-agent" invoker "test-user" group "conv-filter" role "USER" content "filter-test-message" and metadata "filter-marker"' +) +def message_exists_filter(context, agent_memory_client): + context["client"] = agent_memory_client + context["message"] = agent_memory_client.add_message( + "test-agent", + "test-user", + "conv-filter", + "USER", + "filter-test-message", + metadata={"tag": "filter-marker"}, + ) + + +@when('I list memories filtered by content containing "dark mode"') +def list_memories_by_content(context): + from sap_cloud_sdk.agent_memory import FilterDefinition + + client: AgentMemoryClient = context["client"] + context["memories"] = client.list_memories( + agent_id="test-agent", + invoker_id="test-user", + filters=[FilterDefinition(target="content", contains="dark mode")], + ) + + +@when('I list messages filtered by metadata containing "filter-marker"') +def list_messages_by_metadata(context): + from sap_cloud_sdk.agent_memory import FilterDefinition + + client: AgentMemoryClient = context["client"] + context["messages"] = client.list_messages( + agent_id="test-agent", + invoker_id="test-user", + message_group="conv-filter", + filters=[FilterDefinition(target="metadata", contains="filter-marker")], + ) + + +@then('the result should contain the memory with content "The user prefers dark mode"') +def check_memory_content_in_results(context): + contents = [m.content for m in context["memories"]] + assert "The user prefers dark mode" in contents + + +@then('the result should contain the message with content "filter-test-message"') +def check_message_content_in_results(context): + contents = [m.content for m in context["messages"]] + assert "filter-test-message" in contents diff --git a/tests/agent_memory/unit/__init__.py b/tests/agent_memory/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agent_memory/unit/test_client.py b/tests/agent_memory/unit/test_client.py new file mode 100644 index 0000000..feb0b18 --- /dev/null +++ b/tests/agent_memory/unit/test_client.py @@ -0,0 +1,999 @@ +"""Unit tests for AgentMemoryClient operations (v1 API).""" + +import pytest +from unittest.mock import MagicMock, patch + +from sap_cloud_sdk.agent_memory._endpoints import ( + MEMORIES, + MEMORY_SEARCH, + MESSAGES, + RETENTION_CONFIG, +) +from sap_cloud_sdk.agent_memory._http_transport import HttpTransport +from sap_cloud_sdk.agent_memory._models import ( + Memory, + Message, + MessageRole, + RetentionConfig, + SearchResult, +) +from sap_cloud_sdk.agent_memory.client import AgentMemoryClient +from sap_cloud_sdk.agent_memory import create_client, FilterDefinition +from sap_cloud_sdk.agent_memory.config import AgentMemoryConfig +from sap_cloud_sdk.agent_memory.exceptions import AgentMemoryValidationError + + +def _make_client() -> tuple[AgentMemoryClient, MagicMock]: + """Return an AgentMemoryClient with a mocked transport layer.""" + transport = MagicMock(spec=HttpTransport) + client = AgentMemoryClient(transport) + return client, transport + + +# ── create_client factory ───────────────────────────────────────────────────── + + +class TestCreateClient: + + def test_uses_provided_config(self): + """Factory accepts an explicit config object.""" + config = AgentMemoryConfig(base_url="http://localhost:3000") + with patch("sap_cloud_sdk.agent_memory.HttpTransport") as MockTransport: + MockTransport.return_value = MagicMock(spec=HttpTransport) + client = create_client(config=config) + assert isinstance(client, AgentMemoryClient) + + def test_reads_env_when_no_config_provided(self, monkeypatch): + """Factory falls back to environment variables when no config given.""" + monkeypatch.setenv("CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_APPLICATION_URL", "http://memory.example.com") + monkeypatch.setenv("CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_URL", "http://auth.example.com") + monkeypatch.setenv("CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTID", "client-id") + monkeypatch.setenv("CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTSECRET", "client-secret") + with patch("sap_cloud_sdk.agent_memory.HttpTransport") as MockTransport: + MockTransport.return_value = MagicMock(spec=HttpTransport) + client = create_client() + assert isinstance(client, AgentMemoryClient) + + +# ── Memory CRUD operations ──────────────────────────────────────────────────── + + +class TestMemoryCRUD: + + def test_add_memory_posts_correct_payload(self): + """add_memory sends required and optional fields in the POST body.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "mem-1", + "agentID": "agent-a", + "invokerID": "user-b", + "content": "some memory", + "createType": "DIRECT", + } + + memory = client.add_memory("agent-a", "user-b", "some memory") + + assert isinstance(memory, Memory) + assert memory.id == "mem-1" + payload = mock_transport.post.call_args[1]["json"] + assert payload["agentID"] == "agent-a" + assert payload["invokerID"] == "user-b" + assert payload["content"] == "some memory" + + def test_add_memory_with_metadata(self): + """Optional metadata is included in the POST body when provided.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "mem-1", "agentID": "a", "invokerID": "u", "content": "x", + } + + client.add_memory("a", "u", "x", metadata={"key": "val"}) + + payload = mock_transport.post.call_args[1]["json"] + assert payload["metadata"] == {"key": "val"} + + def test_add_memory_excludes_none_optionals(self): + """None-valued optional fields are omitted from the POST body.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "mem-1", "agentID": "a", "invokerID": "u", "content": "x", + } + + client.add_memory("a", "u", "x") + + payload = mock_transport.post.call_args[1]["json"] + assert "metadata" not in payload + assert "createType" not in payload + + def test_add_memory_posts_to_memories_endpoint(self): + """add_memory sends the POST to the MEMORIES endpoint.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "mem-1", "agentID": "a", "invokerID": "u", "content": "x", + } + + client.add_memory("a", "u", "x") + + call_path = mock_transport.post.call_args[0][0] + assert call_path == MEMORIES + + def test_get_memory_calls_get_with_memory_id(self): + """get_memory constructs the correct path with the memory ID.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = { + "id": "mem-1", "agentID": "a", "invokerID": "u", "content": "hello", + } + + memory = client.get_memory("mem-1") + + assert memory.id == "mem-1" + call_path = mock_transport.get.call_args[0][0] + assert call_path == f"{MEMORIES}(mem-1)" + + def test_update_memory_calls_patch(self): + """update_memory sends a PATCH with the updated fields.""" + client, mock_transport = _make_client() + + client.update_memory("mem-1", content="updated") + + mock_transport.patch.assert_called_once() + payload = mock_transport.patch.call_args[1]["json"] + assert payload["content"] == "updated" + + def test_update_memory_excludes_none_fields(self): + """update_memory omits None-valued optional fields from the PATCH body.""" + client, mock_transport = _make_client() + + client.update_memory("mem-1", content="x") + + payload = mock_transport.patch.call_args[1]["json"] + assert "metadata" not in payload + + def test_update_memory_with_metadata_only(self): + """update_memory supports updating metadata without content.""" + client, mock_transport = _make_client() + + client.update_memory("mem-1", metadata={"key": "new-meta"}) + + payload = mock_transport.patch.call_args[1]["json"] + assert payload["metadata"] == {"key": "new-meta"} + assert "content" not in payload + + def test_delete_memory_calls_delete(self): + """delete_memory sends a DELETE to the correct path.""" + client, mock_transport = _make_client() + + client.delete_memory("mem-1") + + mock_transport.delete.assert_called_once() + call_path = mock_transport.delete.call_args[0][0] + assert call_path == f"{MEMORIES}(mem-1)" + + +# ── Memory listing ──────────────────────────────────────────────────────────── + + +class TestListMemories: + + def test_returns_list_of_memories(self): + """list_memories returns a list of Memory objects.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = { + "value": [ + {"id": "m1", "agentID": "a", "invokerID": "u", "content": "memory 1"}, + ], + } + + memories = client.list_memories(agent_id="a", invoker_id="u") + + assert len(memories) == 1 + assert isinstance(memories[0], Memory) + + def test_passes_filter_for_agent_and_invoker(self): + """Convenience agent_id/invoker_id args are converted to $filter.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories(agent_id="agent-x", invoker_id="user-y") + + params = mock_transport.get.call_args[1]["params"] + assert "agentID eq 'agent-x'" in params["$filter"] + assert "invokerID eq 'user-y'" in params["$filter"] + + def test_default_limit_is_50(self): + """Default limit is 50 ($top=50).""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories() + + params = mock_transport.get.call_args[1]["params"] + assert params["$top"] == "50" + + def test_custom_limit(self): + """Custom limit is forwarded as $top.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories(limit=5) + + params = mock_transport.get.call_args[1]["params"] + assert params["$top"] == "5" + + def test_empty_list(self): + """list_memories handles empty responses correctly.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + memories = client.list_memories() + + assert len(memories) == 0 + + def test_offset_passes_skip_param(self): + """Non-zero offset is forwarded as $skip.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories(offset=50) + + params = mock_transport.get.call_args[1]["params"] + assert params["$skip"] == "50" + + def test_zero_offset_omits_skip_param(self): + """Default offset of 0 does not add $skip to the request.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories() + + params = mock_transport.get.call_args[1]["params"] + assert "$skip" not in params + + def test_filter_metadata_contains_adds_contains_clause(self): + """A metadata FilterDefinition produces a contains(metadata, ...) expression.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories( + filters=[FilterDefinition(target="metadata", contains="john")], + ) + + params = mock_transport.get.call_args[1]["params"] + assert "contains(metadata, 'john')" in params["$filter"] + + def test_filter_content_contains_adds_contains_clause(self): + """A content FilterDefinition produces a contains(content, ...) expression.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories( + filters=[FilterDefinition(target="content", contains="dark mode")], + ) + + params = mock_transport.get.call_args[1]["params"] + assert "contains(content, 'dark mode')" in params["$filter"] + + def test_filter_multiple_clauses_joined_with_and(self): + """Multiple FilterDefinitions are joined with 'and' in $filter.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories( + filters=[ + FilterDefinition(target="metadata", contains="john"), + FilterDefinition(target="content", contains="user prefers"), + ], + ) + + params = mock_transport.get.call_args[1]["params"] + f = params["$filter"] + assert "contains(metadata, 'john')" in f + assert "contains(content, 'user prefers')" in f + assert " and " in f + + def test_filter_combines_with_agent_and_invoker_filters(self): + """FilterDefinitions are combined with agent_id/invoker_id eq predicates.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories( + agent_id="my-agent", + invoker_id="user-1", + filters=[FilterDefinition(target="content", contains="dark mode")], + ) + + params = mock_transport.get.call_args[1]["params"] + f = params["$filter"] + assert "agentID eq 'my-agent'" in f + assert "invokerID eq 'user-1'" in f + assert "contains(content, 'dark mode')" in f + + def test_filter_none_does_not_change_behaviour(self): + """filter=None produces the same $filter as before (no regression).""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_memories(agent_id="a", invoker_id="u", filters=None) + + params = mock_transport.get.call_args[1]["params"] + assert params["$filter"] == "agentID eq 'a' and invokerID eq 'u'" + + +class TestCountMemories: + + def test_returns_count_from_response(self): + """count_memories returns the @odata.count value.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": [], "@odata.count": 42} + + total = client.count_memories(agent_id="a", invoker_id="u") + + assert total == 42 + + def test_sends_top_0_and_count_true(self): + """count_memories uses $top=0 and $count=true.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": [], "@odata.count": 0} + + client.count_memories() + + params = mock_transport.get.call_args[1]["params"] + assert params["$top"] == "0" + assert params["$count"] == "true" + + def test_passes_filter_when_agent_and_invoker_provided(self): + """count_memories forwards agent_id and invoker_id as $filter.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": [], "@odata.count": 3} + + client.count_memories(agent_id="agt", invoker_id="usr") + + params = mock_transport.get.call_args[1]["params"] + assert "agentID eq 'agt'" in params["$filter"] + assert "invokerID eq 'usr'" in params["$filter"] + + def test_returns_zero_when_count_missing(self): + """count_memories returns 0 when count is absent from response.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + total = client.count_memories() + + assert total == 0 + + +# ── Memory search ───────────────────────────────────────────────────────────── + + +class TestSearchMemories: + + def test_returns_results_in_api_order(self): + """search_memories returns results in the order returned by the API.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "value": [ + {"id": "m1", "agentID": "a", "invokerID": "u", "content": "first", "similarity": 0.5}, + {"id": "m2", "agentID": "a", "invokerID": "u", "content": "second", "similarity": 0.9}, + ] + } + + results = client.search_memories("a", "u", "test query") + + assert len(results) == 2 + assert all(isinstance(r, SearchResult) for r in results) + assert results[0].similarity == 0.5 + assert results[1].similarity == 0.9 + + def test_posts_correct_payload(self): + """search_memories sends the correct payload to the search endpoint.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = {"value": []} + + client.search_memories("agent-a", "user-b", "my query", threshold=0.7, limit=5) + + call_path = mock_transport.post.call_args[0][0] + assert call_path == MEMORY_SEARCH + payload = mock_transport.post.call_args[1]["json"] + assert payload["agentID"] == "agent-a" + assert payload["invokerID"] == "user-b" + assert payload["query"] == "my query" + assert payload["threshold"] == 0.7 + assert payload["top"] == 5 + + def test_empty_results(self): + """search_memories handles empty search results.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = {"value": []} + + results = client.search_memories("a", "u", "empty query") + + assert len(results) == 0 + + def test_uses_default_threshold_and_limit(self): + """search_memories uses default threshold=0.6 and limit=10.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = {"value": []} + + client.search_memories("a", "u", "query") + + payload = mock_transport.post.call_args[1]["json"] + assert payload["threshold"] == 0.6 + assert payload["top"] == 10 + assert "skip" not in payload + + +# ── Message operations ──────────────────────────────────────────────────────── + + +class TestMessageCRUD: + + def test_add_message_posts_correct_payload(self): + """add_message sends required fields in the POST body.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "msg-1", + "agentID": "agent-a", + "invokerID": "user-b", + "messageGroup": "conv-1", + "role": "USER", + "content": "Hello!", + } + + message = client.add_message( + "agent-a", "user-b", "conv-1", MessageRole.USER, "Hello!", + ) + + assert isinstance(message, Message) + assert message.id == "msg-1" + assert message.role == "USER" + payload = mock_transport.post.call_args[1]["json"] + assert payload["agentID"] == "agent-a" + assert payload["invokerID"] == "user-b" + assert payload["messageGroup"] == "conv-1" + assert payload["role"] == "USER" + assert payload["content"] == "Hello!" + + def test_add_message_posts_to_messages_endpoint(self): + """add_message sends the POST to the MESSAGES endpoint.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + } + + client.add_message("a", "u", "g", MessageRole.USER, "hi") + + call_path = mock_transport.post.call_args[0][0] + assert call_path == MESSAGES + + def test_add_message_with_metadata(self): + """Optional metadata is included when provided.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + "metadata": {"key": "val"}, + } + + client.add_message("a", "u", "g", MessageRole.USER, "hi", metadata={"key": "val"}) + + payload = mock_transport.post.call_args[1]["json"] + assert payload["metadata"] == {"key": "val"} + + def test_add_message_excludes_none_metadata(self): + """None-valued metadata is omitted from the POST body.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + } + + client.add_message("a", "u", "g", MessageRole.USER, "hi") + + payload = mock_transport.post.call_args[1]["json"] + assert "metadata" not in payload + + def test_get_message_calls_get_with_message_id(self): + """get_message constructs the correct path with the message ID.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + } + + message = client.get_message("msg-1") + + assert message.id == "msg-1" + call_path = mock_transport.get.call_args[0][0] + assert call_path == f"{MESSAGES}(msg-1)" + + def test_delete_message_calls_delete(self): + """delete_message sends a DELETE to the correct path.""" + client, mock_transport = _make_client() + + client.delete_message("msg-1") + + mock_transport.delete.assert_called_once() + call_path = mock_transport.delete.call_args[0][0] + assert call_path == f"{MESSAGES}(msg-1)" + + +# ── Message listing ─────────────────────────────────────────────────────────── + + +class TestListMessages: + + def test_returns_list_of_messages(self): + """list_messages returns a list of Message objects.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = { + "value": [ + { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + }, + ], + } + + messages = client.list_messages(agent_id="a", invoker_id="u") + + assert len(messages) == 1 + assert isinstance(messages[0], Message) + + def test_passes_convenience_filters(self): + """Convenience filters are converted to $filter.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages( + agent_id="a", invoker_id="u", + message_group="conv-1", role="USER", + ) + + params = mock_transport.get.call_args[1]["params"] + f = params["$filter"] + assert "agentID eq 'a'" in f + assert "invokerID eq 'u'" in f + assert "messageGroup eq 'conv-1'" in f + assert "role eq 'USER'" in f + + def test_default_limit_is_50(self): + """Default limit is 50 ($top=50).""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages() + + params = mock_transport.get.call_args[1]["params"] + assert params["$top"] == "50" + + def test_custom_limit(self): + """Custom limit is forwarded as $top.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages(limit=20) + + params = mock_transport.get.call_args[1]["params"] + assert params["$top"] == "20" + + def test_empty_list(self): + """list_messages handles empty responses correctly.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + messages = client.list_messages() + + assert len(messages) == 0 + + def test_offset_passes_skip_param(self): + """Non-zero offset is forwarded as $skip.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages(offset=100) + + params = mock_transport.get.call_args[1]["params"] + assert params["$skip"] == "100" + + def test_zero_offset_omits_skip_param(self): + """Default offset of 0 does not add $skip to the request.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages() + + params = mock_transport.get.call_args[1]["params"] + assert "$skip" not in params + + def test_filter_metadata_contains_adds_contains_clause(self): + """A metadata FilterDefinition produces a contains(metadata, ...) expression.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages( + filters=[FilterDefinition(target="metadata", contains="demo-app")], + ) + + params = mock_transport.get.call_args[1]["params"] + assert "contains(metadata, 'demo-app')" in params["$filter"] + + def test_filter_content_contains_adds_contains_clause(self): + """A content FilterDefinition produces a contains(content, ...) expression.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages( + filters=[FilterDefinition(target="content", contains="invoice")], + ) + + params = mock_transport.get.call_args[1]["params"] + assert "contains(content, 'invoice')" in params["$filter"] + + def test_filter_multiple_clauses_joined_with_and(self): + """Multiple FilterDefinitions are joined with 'and' in $filter.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages( + filters=[ + FilterDefinition(target="metadata", contains="john"), + FilterDefinition(target="content", contains="user prefers"), + ], + ) + + params = mock_transport.get.call_args[1]["params"] + f = params["$filter"] + assert "contains(metadata, 'john')" in f + assert "contains(content, 'user prefers')" in f + assert " and " in f + + def test_filter_combines_with_convenience_filters(self): + """FilterDefinitions are combined with all convenience filter predicates.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages( + agent_id="a", + invoker_id="u", + message_group="g", + role="USER", + filters=[FilterDefinition(target="content", contains="hello")], + ) + + params = mock_transport.get.call_args[1]["params"] + f = params["$filter"] + assert "agentID eq 'a'" in f + assert "invokerID eq 'u'" in f + assert "messageGroup eq 'g'" in f + assert "role eq 'USER'" in f + assert "contains(content, 'hello')" in f + + def test_filter_none_does_not_change_behaviour(self): + """filter=None produces the same $filter as before (no regression).""" + client, mock_transport = _make_client() + mock_transport.get.return_value = {"value": []} + + client.list_messages(agent_id="a", invoker_id="u", filters=None) + + params = mock_transport.get.call_args[1]["params"] + assert params["$filter"] == "agentID eq 'a' and invokerID eq 'u'" + + +# ── Admin: Retention Config ─────────────────────────────────────────────────────── + + +class TestRetentionConfig: + + def test_get_retention_config(self): + """get_retention_config sends GET to the retentionConfig endpoint.""" + client, mock_transport = _make_client() + mock_transport.get.return_value = { + "id": 1, "messageDays": 30, "memoryDays": 90, + "usageLogDays": 180, + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + + rc = client.get_retention_config() + + assert isinstance(rc, RetentionConfig) + assert rc.id == 1 + assert rc.message_days == 30 + assert rc.memory_days == 90 + assert rc.usage_log_days == 180 + call_path = mock_transport.get.call_args[0][0] + assert call_path == RETENTION_CONFIG + + def test_update_retention_config(self): + """update_retention_config sends PATCH with updated fields.""" + client, mock_transport = _make_client() + + client.update_retention_config(message_days=60) + + mock_transport.patch.assert_called_once() + call_path = mock_transport.patch.call_args[0][0] + assert call_path == RETENTION_CONFIG + payload = mock_transport.patch.call_args[1]["json"] + assert payload["messageDays"] == 60 + assert "memoryDays" not in payload + + def test_update_retention_config_excludes_none_fields(self): + """update_retention_config omits None-valued fields from PATCH body.""" + client, mock_transport = _make_client() + + client.update_retention_config(memory_days=90, usage_log_days=180) + + payload = mock_transport.patch.call_args[1]["json"] + assert "messageDays" not in payload + assert payload["memoryDays"] == 90 + assert payload["usageLogDays"] == 180 + + +# ── Context manager ─────────────────────────────────────────────────────────── + + +class TestContextManager: + + def test_close_delegates_to_transport(self): + """close() delegates to the transport's close method.""" + client, mock_transport = _make_client() + + client.close() + + mock_transport.close.assert_called_once() + + def test_context_manager_closes_on_exit(self): + """Using the client as a context manager closes it on __exit__.""" + transport = MagicMock(spec=HttpTransport) + client = AgentMemoryClient(transport) + + with client: + pass + + transport.close.assert_called_once() + + +# ── Validation ──────────────────────────────────────────────────────────────── + + +class TestMemoryValidation: + + def test_add_memory_raises_for_empty_agent_id(self): + """add_memory raises AgentMemoryValidationError when agent_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="agent_id"): + client.add_memory("", "user-1", "content") + + def test_add_memory_raises_for_empty_invoker_id(self): + """add_memory raises AgentMemoryValidationError when invoker_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="invoker_id"): + client.add_memory("agent-1", "", "content") + + def test_add_memory_raises_for_empty_content(self): + """add_memory raises AgentMemoryValidationError when content is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="content"): + client.add_memory("agent-1", "user-1", "") + + def test_get_memory_raises_for_empty_id(self): + """get_memory raises AgentMemoryValidationError when memory_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="memory_id"): + client.get_memory("") + + def test_update_memory_raises_for_empty_id(self): + """update_memory raises AgentMemoryValidationError when memory_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="memory_id"): + client.update_memory("", content="new content") + + def test_update_memory_raises_when_no_fields_provided(self): + """update_memory raises AgentMemoryValidationError when neither content nor metadata is provided.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="At least one"): + client.update_memory("uuid-123") + + def test_delete_memory_raises_for_empty_id(self): + """delete_memory raises AgentMemoryValidationError when memory_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="memory_id"): + client.delete_memory("") + + def test_list_memories_raises_for_zero_limit(self): + """list_memories raises AgentMemoryValidationError when limit is 0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="limit"): + client.list_memories(limit=0) + + def test_list_memories_raises_for_negative_offset(self): + """list_memories raises AgentMemoryValidationError when offset is negative.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="offset"): + client.list_memories(offset=-1) + + +class TestSearchMemoriesValidation: + + def test_raises_for_empty_agent_id(self): + """search_memories raises AgentMemoryValidationError when agent_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="agent_id"): + client.search_memories("", "user-1", "what do I know about Python?") + + def test_raises_for_empty_invoker_id(self): + """search_memories raises AgentMemoryValidationError when invoker_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="invoker_id"): + client.search_memories("agent-1", "", "what do I know about Python?") + + def test_raises_for_query_too_short(self): + """search_memories raises AgentMemoryValidationError when query has fewer than 5 chars.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="query"): + client.search_memories("agent-1", "user-1", "hi") + + def test_raises_for_query_too_long(self): + """search_memories raises AgentMemoryValidationError when query exceeds 5000 chars.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="query"): + client.search_memories("agent-1", "user-1", "x" * 5001) + + def test_raises_for_threshold_below_zero(self): + """search_memories raises AgentMemoryValidationError when threshold < 0.0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="threshold"): + client.search_memories("a", "u", "valid query here", threshold=-0.1) + + def test_raises_for_threshold_above_one(self): + """search_memories raises AgentMemoryValidationError when threshold > 1.0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="threshold"): + client.search_memories("a", "u", "valid query here", threshold=1.1) + + def test_raises_for_limit_zero(self): + """search_memories raises AgentMemoryValidationError when limit is 0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="limit"): + client.search_memories("a", "u", "valid query here", limit=0) + + def test_raises_for_limit_above_fifty(self): + """search_memories raises AgentMemoryValidationError when limit exceeds 50.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="limit"): + client.search_memories("a", "u", "valid query here", limit=51) + + def test_boundary_values_are_accepted(self): + """search_memories accepts boundary values: 5-char query, threshold 0.0/1.0, limit 1/50.""" + client, mock_transport = _make_client() + mock_transport.post.return_value = {"value": []} + + client.search_memories("a", "u", "hello", threshold=0.0, limit=1) + client.search_memories("a", "u", "x" * 5000, threshold=1.0, limit=50) + + assert mock_transport.post.call_count == 2 + + +class TestMessageValidation: + + def test_add_message_raises_for_empty_agent_id(self): + """add_message raises AgentMemoryValidationError when agent_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="agent_id"): + client.add_message("", "u", "grp", MessageRole.USER, "hi") + + def test_add_message_raises_for_empty_invoker_id(self): + """add_message raises AgentMemoryValidationError when invoker_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="invoker_id"): + client.add_message("a", "", "grp", MessageRole.USER, "hi") + + def test_add_message_raises_for_empty_message_group(self): + """add_message raises AgentMemoryValidationError when message_group is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="message_group"): + client.add_message("a", "u", "", MessageRole.USER, "hi") + + def test_add_message_raises_for_empty_content(self): + """add_message raises AgentMemoryValidationError when content is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="content"): + client.add_message("a", "u", "grp", MessageRole.USER, "") + + def test_get_message_raises_for_empty_id(self): + """get_message raises AgentMemoryValidationError when message_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="message_id"): + client.get_message("") + + def test_delete_message_raises_for_empty_id(self): + """delete_message raises AgentMemoryValidationError when message_id is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="message_id"): + client.delete_message("") + + def test_list_messages_raises_for_zero_limit(self): + """list_messages raises AgentMemoryValidationError when limit is 0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="limit"): + client.list_messages(limit=0) + + def test_list_messages_raises_for_negative_offset(self): + """list_messages raises AgentMemoryValidationError when offset is negative.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="offset"): + client.list_messages(offset=-1) + + +class TestRetentionConfigValidation: + + def test_update_raises_when_no_fields_provided(self): + """update_retention_config raises AgentMemoryValidationError when no fields are provided.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="At least one"): + client.update_retention_config() + + def test_update_raises_for_negative_message_days(self): + """update_retention_config raises AgentMemoryValidationError when message_days < 0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="message_days"): + client.update_retention_config(message_days=-1) + + def test_update_raises_for_negative_memory_days(self): + """update_retention_config raises AgentMemoryValidationError when memory_days < 0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="memory_days"): + client.update_retention_config(memory_days=-1) + + def test_update_raises_for_negative_usage_log_days(self): + """update_retention_config raises AgentMemoryValidationError when usage_log_days < 0.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="usage_log_days"): + client.update_retention_config(usage_log_days=-1) + + def test_update_accepts_zero_values(self): + """update_retention_config accepts 0 as a valid value (disables cleanup).""" + client, mock_transport = _make_client() + + client.update_retention_config(memory_days=0) + + mock_transport.patch.assert_called_once() + + +# ── FilterDefinition validation ─────────────────────────────────────────────────── + + +class TestFilterDefinitionValidation: + + def test_list_memories_raises_for_unsupported_target(self): + """list_memories raises AgentMemoryValidationError for an unknown target.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="target"): + client.list_memories( + filters=[FilterDefinition(target="agentID", contains="x")], + ) + + def test_list_memories_raises_for_empty_contains(self): + """list_memories raises AgentMemoryValidationError when contains is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="contains"): + client.list_memories( + filters=[FilterDefinition(target="content", contains="")], + ) + + def test_list_messages_raises_for_unsupported_target(self): + """list_messages raises AgentMemoryValidationError for an unknown target.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="target"): + client.list_messages( + filters=[FilterDefinition(target="role", contains="x")], + ) + + def test_list_messages_raises_for_empty_contains(self): + """list_messages raises AgentMemoryValidationError when contains is empty.""" + client, _ = _make_client() + with pytest.raises(AgentMemoryValidationError, match="contains"): + client.list_messages( + filters=[FilterDefinition(target="metadata", contains="")], + ) diff --git a/tests/agent_memory/unit/test_config.py b/tests/agent_memory/unit/test_config.py new file mode 100644 index 0000000..fc94745 --- /dev/null +++ b/tests/agent_memory/unit/test_config.py @@ -0,0 +1,174 @@ +"""Unit tests for AgentMemoryConfig, BindingData, and _load_config_from_env.""" + +from unittest.mock import patch + +import pytest + +from sap_cloud_sdk.agent_memory.config import ( + AgentMemoryConfig, + BindingData, + _load_config_from_env, +) +from sap_cloud_sdk.agent_memory.exceptions import AgentMemoryConfigError + + +# ── AgentMemoryConfig ───────────────────────────────────────────────────────── + + +class TestAgentMemoryConfig: + def test_raises_when_base_url_empty(self): + """AgentMemoryConfig rejects an empty base_url.""" + with pytest.raises(AgentMemoryConfigError, match="base_url"): + AgentMemoryConfig(base_url="") + + def test_optional_fields_default_to_none(self): + """token_url, client_id, and client_secret default to None.""" + config = AgentMemoryConfig(base_url="http://localhost:8080") + assert config.token_url is None + assert config.client_id is None + assert config.client_secret is None + + def test_timeout_default(self): + """Default timeout is 30.0 seconds.""" + config = AgentMemoryConfig(base_url="http://localhost:8080") + assert config.timeout == 30.0 + + +# ── BindingData ─────────────────────────────────────────────────────────────── + + +class TestBindingData: + def test_validate_raises_when_all_fields_empty(self): + """validate() raises AgentMemoryConfigError when all fields are empty.""" + with pytest.raises(AgentMemoryConfigError, match="missing required fields"): + BindingData().validate() + + def test_validate_raises_when_some_fields_empty(self): + """validate() raises when only some fields are populated.""" + binding = BindingData(application_url="https://example.com") + with pytest.raises(AgentMemoryConfigError, match="missing required fields"): + binding.validate() + + def test_validate_passes_when_all_fields_set(self): + """validate() does not raise when all required fields are populated.""" + binding = BindingData( + application_url="https://example.com", + uaa_url="https://auth.example.com", + uaa_clientid="client-id", + uaa_clientsecret="client-secret", + ) + binding.validate() # should not raise + + def test_extract_config_derives_token_url(self): + """extract_config() appends /oauth/token to uaa_url.""" + binding = BindingData( + application_url="https://memory.example.com", + uaa_url="https://auth.example.com", + uaa_clientid="cid", + uaa_clientsecret="csec", + ) + config = binding.extract_config() + assert config.token_url == "https://auth.example.com/oauth/token" + + def test_extract_config_strips_trailing_slash_from_uaa_url(self): + """extract_config() strips a trailing slash before appending /oauth/token.""" + binding = BindingData( + application_url="https://memory.example.com", + uaa_url="https://auth.example.com/", + uaa_clientid="cid", + uaa_clientsecret="csec", + ) + config = binding.extract_config() + assert config.token_url == "https://auth.example.com/oauth/token" + + def test_extract_config_maps_all_fields(self): + """extract_config() maps all binding fields to AgentMemoryConfig.""" + binding = BindingData( + application_url="https://memory.example.com", + uaa_url="https://auth.example.com", + uaa_clientid="my-client", + uaa_clientsecret="my-secret", + ) + config = binding.extract_config() + assert config.base_url == "https://memory.example.com" + assert config.client_id == "my-client" + assert config.client_secret == "my-secret" + + +# ── _load_config_from_env ───────────────────────────────────────────────────── + +_MOUNT_LOADER = "sap_cloud_sdk.core.secret_resolver.resolver._load_from_mount" + +_ENV_VARS = { + "CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_APPLICATION_URL": "https://memory.example.com", + "CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_URL": "https://auth.example.com", + "CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTID": "env-client", + "CLOUD_SDK_CFG_HANA_AGENT_MEMORY_DEFAULT_UAA_CLIENTSECRET": "env-secret", +} + + +def _fill_binding(_base_volume_mount, _module, _instance, target) -> None: + target.application_url = "https://memory.example.com" + target.uaa_url = "https://auth.example.com" + target.uaa_clientid = "resolved-client" + target.uaa_clientsecret = "resolved-secret" + + +class TestLoadConfigFromEnv: + def test_success_from_mount(self): + """_load_config_from_env() returns a valid AgentMemoryConfig when mount succeeds.""" + with patch(_MOUNT_LOADER, side_effect=_fill_binding): + config = _load_config_from_env() + + assert config.base_url == "https://memory.example.com" + assert config.token_url == "https://auth.example.com/oauth/token" + assert config.client_id == "resolved-client" + assert config.client_secret == "resolved-secret" + + def test_calls_mount_loader_with_correct_arguments(self): + """_load_config_from_env() calls _load_from_mount with the correct path/module/instance.""" + with patch(_MOUNT_LOADER, side_effect=_fill_binding) as mock_mount: + _load_config_from_env() + + mock_mount.assert_called_once() + args = mock_mount.call_args[0] + assert args[0] == "/etc/secrets/appfnd" + assert args[1] == "hana-agent-memory" + assert args[2] == "default" + + def test_falls_back_to_env_when_mount_fails(self, monkeypatch): + """_load_config_from_env() reads env vars when the mount path is unavailable.""" + for var, val in _ENV_VARS.items(): + monkeypatch.setenv(var, val) + + with patch(_MOUNT_LOADER, side_effect=FileNotFoundError("no mount")): + config = _load_config_from_env() + + assert config.base_url == "https://memory.example.com" + assert config.client_id == "env-client" + + def test_raises_config_error_when_mount_and_env_both_fail(self, monkeypatch): + """_load_config_from_env() raises AgentMemoryConfigError when both mount and env vars are absent.""" + for var in _ENV_VARS: + monkeypatch.delenv(var, raising=False) + + with patch(_MOUNT_LOADER, side_effect=FileNotFoundError("no mount")): + with pytest.raises( + AgentMemoryConfigError, match="Missing required environment variables" + ): + _load_config_from_env() + + def test_raises_config_error_when_mount_binding_incomplete_and_env_missing( + self, monkeypatch + ): + """_load_config_from_env() raises AgentMemoryConfigError when mount gives partial data and env is absent.""" + for var in _ENV_VARS: + monkeypatch.delenv(var, raising=False) + + def incomplete_fill(_bvm, _mod, _inst, target): + target.application_url = "https://example.com" + # uaa_url/uaa_clientid/uaa_clientsecret remain empty → validate() raises + + with patch(_MOUNT_LOADER, side_effect=incomplete_fill): + with pytest.raises(AgentMemoryConfigError): + _load_config_from_env() diff --git a/tests/agent_memory/unit/test_exceptions.py b/tests/agent_memory/unit/test_exceptions.py new file mode 100644 index 0000000..b2a8a23 --- /dev/null +++ b/tests/agent_memory/unit/test_exceptions.py @@ -0,0 +1,93 @@ +import pytest + +from sap_cloud_sdk.agent_memory.exceptions import ( + AgentMemoryConfigError, + AgentMemoryError, + AgentMemoryHttpError, + AgentMemoryNotFoundError, +) + + +class TestAgentMemoryError: + + def test_is_base_exception(self): + assert issubclass(AgentMemoryError, Exception) + + def test_message(self): + exc = AgentMemoryError("base error") + assert str(exc) == "base error" + + def test_can_be_raised_and_caught(self): + with pytest.raises(AgentMemoryError, match="test error"): + raise AgentMemoryError("test error") + + +class TestAgentMemoryConfigError: + + def test_is_agent_memory_error(self): + assert issubclass(AgentMemoryConfigError, AgentMemoryError) + + def test_message(self): + exc = AgentMemoryConfigError("missing base_url") + assert str(exc) == "missing base_url" + + def test_caught_as_base_error(self): + with pytest.raises(AgentMemoryError): + raise AgentMemoryConfigError("config problem") + + +class TestAgentMemoryHttpError: + + def test_is_agent_memory_error(self): + assert issubclass(AgentMemoryHttpError, AgentMemoryError) + + def test_message(self): + exc = AgentMemoryHttpError("request failed") + assert str(exc) == "request failed" + + def test_status_code_defaults_to_none(self): + exc = AgentMemoryHttpError("error") + assert exc.status_code is None + + def test_response_text_defaults_to_none(self): + exc = AgentMemoryHttpError("error") + assert exc.response_text is None + + def test_status_code_is_set(self): + exc = AgentMemoryHttpError("error", status_code=500) + assert exc.status_code == 500 + + def test_response_text_is_set(self): + exc = AgentMemoryHttpError("error", response_text="Internal Server Error") + assert exc.response_text == "Internal Server Error" + + def test_both_attrs_set(self): + exc = AgentMemoryHttpError("error", status_code=422, response_text="Bad input") + assert exc.status_code == 422 + assert exc.response_text == "Bad input" + + def test_caught_as_base_error(self): + with pytest.raises(AgentMemoryError): + raise AgentMemoryHttpError("http failure", status_code=503) + + +class TestAgentMemoryNotFoundError: + + def test_is_http_error(self): + assert issubclass(AgentMemoryNotFoundError, AgentMemoryHttpError) + + def test_is_agent_memory_error(self): + assert issubclass(AgentMemoryNotFoundError, AgentMemoryError) + + def test_message_and_status_code(self): + exc = AgentMemoryNotFoundError("not found", status_code=404) + assert str(exc) == "not found" + assert exc.status_code == 404 + + def test_caught_as_http_error(self): + with pytest.raises(AgentMemoryHttpError): + raise AgentMemoryNotFoundError("not found", status_code=404) + + def test_caught_as_base_error(self): + with pytest.raises(AgentMemoryError): + raise AgentMemoryNotFoundError("not found") diff --git a/tests/agent_memory/unit/test_http_transport.py b/tests/agent_memory/unit/test_http_transport.py new file mode 100644 index 0000000..282abe3 --- /dev/null +++ b/tests/agent_memory/unit/test_http_transport.py @@ -0,0 +1,273 @@ +"""Unit tests for HttpTransport.""" + +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from sap_cloud_sdk.agent_memory._http_transport import ( + HttpTransport, + _TOKEN_EXPIRY_BUFFER_SECONDS, +) +from sap_cloud_sdk.agent_memory.config import AgentMemoryConfig +from sap_cloud_sdk.agent_memory.exceptions import AgentMemoryHttpError, AgentMemoryNotFoundError + + +def _config(with_auth: bool = True) -> AgentMemoryConfig: + if with_auth: + return AgentMemoryConfig( + base_url="http://localhost:8080", + token_url="http://localhost:8080/oauth/token", + client_id="client-id", + client_secret="client-secret", + ) + return AgentMemoryConfig(base_url="http://localhost:8080") + + +def _mock_response( + status_code: int, + json_data: dict | None = None, + text: str = "", +) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.ok = 200 <= status_code < 300 + response.content = b"content" if json_data is not None else b"" + response.text = text + response.json.return_value = json_data or {} + return response + + +# ── No-auth local dev mode ──────────────────────────────────────────────────── + + +class TestNoAuthMode: + + def test_sends_request_without_authorization_header(self): + """No-auth mode does not send an Authorization header.""" + transport = HttpTransport(_config(with_auth=False)) + mock_session = MagicMock() + transport._plain_session = mock_session + mock_session.request.return_value = _mock_response(200, {"data": []}) + + transport.get("/test") + + _, kwargs = mock_session.request.call_args + assert "Authorization" not in kwargs.get("headers", {}) + + def test_uses_plain_session_when_no_token_url(self): + """No-auth mode uses a plain requests.Session, not OAuth2Session.""" + with patch("sap_cloud_sdk.agent_memory._http_transport.requests") as mock_requests: + mock_session = MagicMock() + mock_requests.Session.return_value = mock_session + mock_session.request.return_value = _mock_response(200, {}) + + transport = HttpTransport(_config(with_auth=False)) + transport.get("/test") + + mock_requests.Session.assert_called_once() + + +# ── Token acquisition ───────────────────────────────────────────────────────── + + +class TestTokenAcquisition: + + def test_token_is_fetched_and_cached(self): + """fetch_token is called only once across multiple requests.""" + with patch( + "sap_cloud_sdk.agent_memory._http_transport.OAuth2Session" + ) as MockOAuth, patch( + "sap_cloud_sdk.agent_memory._http_transport.BackendApplicationClient" + ): + mock_oauth = MagicMock() + MockOAuth.return_value = mock_oauth + mock_oauth.fetch_token.return_value = { + "access_token": "my-token", + "expires_in": 3600, + } + mock_oauth.request.return_value = _mock_response(200, {"data": []}) + + transport = HttpTransport(_config()) + transport.get("/test") + transport.get("/test") + + assert mock_oauth.fetch_token.call_count == 1 + + def test_expired_token_triggers_refetch(self): + """Expired token causes a new fetch_token call.""" + with patch( + "sap_cloud_sdk.agent_memory._http_transport.OAuth2Session" + ) as MockOAuth, patch( + "sap_cloud_sdk.agent_memory._http_transport.BackendApplicationClient" + ): + mock_oauth = MagicMock() + MockOAuth.return_value = mock_oauth + mock_oauth.fetch_token.return_value = { + "access_token": "token", + "expires_in": 3600, + } + mock_oauth.request.return_value = _mock_response(200, {}) + + transport = HttpTransport(_config()) + transport._token_expires_at = datetime.now() - timedelta(seconds=1) + transport.get("/test") + transport.get("/test") + + # Two fetches: one for each request since we started with an expired timestamp + assert mock_oauth.fetch_token.call_count >= 1 + + def test_token_expiry_uses_buffer(self): + """Token expiry is set with _TOKEN_EXPIRY_BUFFER_SECONDS subtracted.""" + with patch( + "sap_cloud_sdk.agent_memory._http_transport.OAuth2Session" + ) as MockOAuth, patch( + "sap_cloud_sdk.agent_memory._http_transport.BackendApplicationClient" + ): + mock_oauth = MagicMock() + MockOAuth.return_value = mock_oauth + mock_oauth.fetch_token.return_value = { + "access_token": "tok", + "expires_in": 3600, + } + mock_oauth.request.return_value = _mock_response(200, {}) + + transport = HttpTransport(_config()) + transport.get("/test") + + expected_max = datetime.now() + timedelta( + seconds=3600 - _TOKEN_EXPIRY_BUFFER_SECONDS + 5 + ) + assert transport._token_expires_at is not None + assert transport._token_expires_at < expected_max + + def test_token_fetch_failure_raises_http_error(self): + """Failed token fetch raises AgentMemoryHttpError.""" + with patch( + "sap_cloud_sdk.agent_memory._http_transport.OAuth2Session" + ) as MockOAuth, patch( + "sap_cloud_sdk.agent_memory._http_transport.BackendApplicationClient" + ): + mock_oauth = MagicMock() + MockOAuth.return_value = mock_oauth + mock_oauth.fetch_token.side_effect = Exception("connection refused") + + transport = HttpTransport(_config()) + with pytest.raises(AgentMemoryHttpError, match="OAuth2 token"): + transport.get("/test") + + +# ── HTTP methods ────────────────────────────────────────────────────────────── + + +class TestHttpMethods: + + def _transport_no_auth(self) -> tuple[HttpTransport, MagicMock]: + transport = HttpTransport(_config(with_auth=False)) + mock_session = MagicMock() + transport._plain_session = mock_session + return transport, mock_session + + def test_get_sends_get_request(self): + """GET request is constructed with the correct method and URL.""" + transport, mock_session = self._transport_no_auth() + mock_session.request.return_value = _mock_response(200, {"key": "value"}) + + result = transport.get("/memories", params={"$top": "10"}) + + mock_session.request.assert_called_once() + call_args = mock_session.request.call_args + assert call_args[0][0] == "GET" + assert call_args[0][1].startswith("http://localhost:8080/memories") + assert "%24top=10" in call_args[0][1] + assert result == {"key": "value"} + + def test_post_sends_post_request(self): + """POST request is constructed with the correct method.""" + transport, mock_session = self._transport_no_auth() + mock_session.request.return_value = _mock_response(201, {"id": "new-memory"}) + + result = transport.post("/memories", json={"agentID": "a"}) + + assert mock_session.request.call_args[0][0] == "POST" + assert result == {"id": "new-memory"} + + def test_patch_sends_patch_request(self): + """PATCH request is constructed with the correct method.""" + transport, mock_session = self._transport_no_auth() + mock_session.request.return_value = _mock_response(200, {"id": "mem-1"}) + + result = transport.patch("/memories(mem-1)", json={"content": "updated"}) + + assert mock_session.request.call_args[0][0] == "PATCH" + assert result == {"id": "mem-1"} + + def test_delete_sends_delete_and_returns_none(self): + """DELETE sends correct method and returns None.""" + transport, mock_session = self._transport_no_auth() + mock_session.request.return_value = _mock_response(204) + + result = transport.delete("/memories/abc") + + assert mock_session.request.call_args[0][0] == "DELETE" + assert result is None + + def test_404_raises_not_found_error(self): + """404 responses raise AgentMemoryNotFoundError.""" + transport, mock_session = self._transport_no_auth() + mock_resp = _mock_response(404, text="Not Found") + mock_resp.content = b"Not Found" + mock_session.request.return_value = mock_resp + + with pytest.raises(AgentMemoryNotFoundError) as exc_info: + transport.get("/memories/nonexistent") + + assert exc_info.value.status_code == 404 + + def test_server_error_raises_http_error(self): + """500 responses raise AgentMemoryHttpError with the status code.""" + transport, mock_session = self._transport_no_auth() + mock_resp = _mock_response(500, text="Internal Server Error") + mock_resp.content = b"Internal Server Error" + mock_session.request.return_value = mock_resp + + with pytest.raises(AgentMemoryHttpError) as exc_info: + transport.get("/memories") + + assert exc_info.value.status_code == 500 + + +# ── Close ───────────────────────────────────────────────────────────────────── + + +class TestClose: + + def test_close_clears_oauth_session(self): + """close() clears the OAuth session.""" + with patch( + "sap_cloud_sdk.agent_memory._http_transport.OAuth2Session" + ) as MockOAuth, patch( + "sap_cloud_sdk.agent_memory._http_transport.BackendApplicationClient" + ): + mock_oauth = MagicMock() + MockOAuth.return_value = mock_oauth + mock_oauth.fetch_token.return_value = {"access_token": "tok", "expires_in": 3600} + mock_oauth.request.return_value = _mock_response(200, {}) + + transport = HttpTransport(_config()) + transport.get("/test") + transport.close() + + mock_oauth.close.assert_called_once() + assert transport._oauth is None + + def test_close_clears_plain_session(self): + """close() clears the plain session in no-auth mode.""" + transport = HttpTransport(_config(with_auth=False)) + mock_session = MagicMock() + transport._plain_session = mock_session + + transport.close() + + mock_session.close.assert_called_once() + assert transport._plain_session is None diff --git a/tests/agent_memory/unit/test_models.py b/tests/agent_memory/unit/test_models.py new file mode 100644 index 0000000..aa159e9 --- /dev/null +++ b/tests/agent_memory/unit/test_models.py @@ -0,0 +1,359 @@ +"""Unit tests for Agent Memory data models (v1 API).""" + +from sap_cloud_sdk.agent_memory._models import ( + Memory, + Message, + MessageRole, + RetentionConfig, + SearchResult, +) + + +# ── Memory ──────────────────────────────────────────────────────────────────── + + +class TestMemory: + + def test_from_dict_maps_known_fields(self): + """All known API fields are mapped to model attributes.""" + data = { + "id": "mem-1", + "agentID": "agent-a", + "invokerID": "user-b", + "content": "The user prefers dark mode.", + "metadata": {"key": "val"}, + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + memory = Memory.from_dict(data) + + assert memory.id == "mem-1" + assert memory.agent_id == "agent-a" + assert memory.invoker_id == "user-b" + assert memory.content == "The user prefers dark mode." + assert memory.metadata == {"key": "val"} + assert memory.create_timestamp == "2025-01-01T00:00:00Z" + assert memory.update_timestamp == "2025-01-02T00:00:00Z" + + def test_from_dict_optional_fields_default_to_none(self): + """Optional fields default to None when absent.""" + data = {"id": "m1", "agentID": "a", "invokerID": "u", "content": "hello"} + memory = Memory.from_dict(data) + assert memory.metadata is None + assert memory.create_timestamp is None + assert memory.update_timestamp is None + + def test_from_dict_parses_string_metadata(self): + """String metadata is parsed as JSON.""" + data = {"id": "m1", "agentID": "a", "invokerID": "u", "content": "x", + "metadata": '{"key": "val"}'} + memory = Memory.from_dict(data) + assert memory.metadata == {"key": "val"} + + def test_from_dict_handles_invalid_json_metadata(self): + """Invalid JSON metadata is wrapped in a raw dict.""" + data = {"id": "m1", "agentID": "a", "invokerID": "u", "content": "x", + "metadata": "not-json"} + memory = Memory.from_dict(data) + assert memory.metadata == {"raw": "not-json"} + + def test_from_dict_empty_dict_uses_defaults(self): + """An empty dict produces safe defaults.""" + memory = Memory.from_dict({}) + assert memory.id == "" + assert memory.content == "" + assert memory.metadata is None + + def test_to_dict_shape(self): + """to_dict outputs all fields with camelCase keys.""" + memory = Memory( + id="m1", agent_id="a", invoker_id="u", content="hello", + create_timestamp="2025-01-01T00:00:00Z", + update_timestamp="2025-01-02T00:00:00Z", + ) + d = memory.to_dict() + assert d == { + "id": "m1", + "agentID": "a", + "invokerID": "u", + "content": "hello", + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + + def test_to_dict_includes_metadata_when_set(self): + """to_dict includes metadata when set.""" + memory = Memory( + id="m1", agent_id="a", invoker_id="u", content="hello", + metadata={"k": "v"}, + ) + d = memory.to_dict() + assert d["metadata"] == {"k": "v"} + + def test_to_dict_omits_metadata_when_none(self): + """to_dict omits metadata when None.""" + memory = Memory(id="m1", agent_id="a", invoker_id="u", content="hello") + d = memory.to_dict() + assert "metadata" not in d + + def test_round_trip(self): + """from_dict(to_dict(m)) reproduces the original object.""" + data = { + "id": "m1", + "agentID": "agent-a", + "invokerID": "user-b", + "content": "The user prefers dark mode.", + "metadata": {"key": "val"}, + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + memory = Memory.from_dict(data) + assert Memory.from_dict(memory.to_dict()) == memory + + +# ── SearchResult ────────────────────────────────────────────────────────────── + + +class TestSearchResult: + + def test_from_dict_maps_all_fields(self): + """All known fields including similarity score are mapped correctly.""" + data = { + "id": "m1", + "agentID": "agent-a", + "invokerID": "user-b", + "content": "user preference", + "similarity": 0.92, + "metadata": {"source": "meta"}, + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + result = SearchResult.from_dict(data) + + assert result.id == "m1" + assert result.agent_id == "agent-a" + assert result.similarity == 0.92 + assert result.metadata == {"source": "meta"} + assert result.create_timestamp == "2025-01-01T00:00:00Z" + assert result.update_timestamp == "2025-01-02T00:00:00Z" + + def test_from_dict_similarity_defaults_to_none(self): + """Similarity defaults to None when absent.""" + data = {"id": "m1", "agentID": "a", "invokerID": "u", "content": "x"} + result = SearchResult.from_dict(data) + assert result.similarity is None + + def test_to_dict_includes_similarity(self): + """to_dict includes the similarity score.""" + result = SearchResult( + id="m1", agent_id="a", invoker_id="u", content="x", + similarity=0.85, + ) + d = result.to_dict() + assert d["similarity"] == 0.85 + + def test_round_trip(self): + """from_dict(to_dict(r)) reproduces the original object.""" + data = { + "id": "m1", + "agentID": "a", + "invokerID": "u", + "content": "user preference", + "similarity": 0.92, + "metadata": {"source": "meta"}, + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + result = SearchResult.from_dict(data) + assert SearchResult.from_dict(result.to_dict()) == result + + +# ── Message ─────────────────────────────────────────────────────────────────── + + +class TestMessage: + + def test_from_dict_maps_known_fields(self): + """All known message fields are mapped correctly.""" + data = { + "id": "msg-1", + "agentID": "agent-a", + "invokerID": "user-b", + "messageGroup": "conv-1", + "role": "USER", + "content": "Hello!", + "metadata": {"key": "val"}, + "createTimestamp": "2025-01-01T00:00:00Z", + } + msg = Message.from_dict(data) + + assert msg.id == "msg-1" + assert msg.agent_id == "agent-a" + assert msg.invoker_id == "user-b" + assert msg.message_group == "conv-1" + assert msg.role == MessageRole.USER + assert msg.content == "Hello!" + assert msg.metadata == {"key": "val"} + assert msg.create_timestamp == "2025-01-01T00:00:00Z" + + def test_from_dict_parses_string_metadata(self): + """String metadata is parsed as JSON.""" + data = { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + "metadata": '{"key": "val"}', + } + msg = Message.from_dict(data) + assert msg.metadata == {"key": "val"} + + def test_from_dict_handles_invalid_json_metadata(self): + """Invalid JSON metadata is wrapped in a raw dict.""" + data = { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + "metadata": "not-json", + } + msg = Message.from_dict(data) + assert msg.metadata == {"raw": "not-json"} + + def test_from_dict_handles_missing_metadata(self): + """Missing metadata defaults to None.""" + data = { + "id": "msg-1", "agentID": "a", "invokerID": "u", + "messageGroup": "g", "role": "USER", "content": "hi", + } + msg = Message.from_dict(data) + assert msg.metadata is None + + def test_from_dict_empty_dict_uses_defaults(self): + """An empty dict produces safe defaults.""" + msg = Message.from_dict({}) + assert msg.id == "" + assert msg.message_group == "" + assert msg.role is None + assert msg.metadata is None + + def test_to_dict_shape(self): + """to_dict outputs all fields with camelCase keys.""" + msg = Message( + id="msg-1", agent_id="a", invoker_id="u", + message_group="g", role=MessageRole.USER, content="hi", + create_timestamp="2025-01-01T00:00:00Z", + ) + d = msg.to_dict() + assert d == { + "id": "msg-1", + "agentID": "a", + "invokerID": "u", + "messageGroup": "g", + "role": "USER", + "content": "hi", + "createTimestamp": "2025-01-01T00:00:00Z", + } + + def test_to_dict_includes_metadata_when_set(self): + """to_dict includes metadata when set.""" + msg = Message( + id="msg-1", agent_id="a", invoker_id="u", + message_group="g", role=MessageRole.USER, content="hi", + metadata={"key": "val"}, + ) + d = msg.to_dict() + assert d["metadata"] == {"key": "val"} + + def test_to_dict_omits_none_role_and_metadata(self): + """to_dict omits role and metadata when None.""" + msg = Message( + id="msg-1", agent_id="a", invoker_id="u", + message_group="g", content="hi", + ) + d = msg.to_dict() + assert "role" not in d + assert "metadata" not in d + + def test_round_trip(self): + """from_dict(to_dict(m)) reproduces the original object.""" + data = { + "id": "msg-1", + "agentID": "agent-a", + "invokerID": "user-b", + "messageGroup": "conv-1", + "role": "USER", + "content": "Hello!", + "metadata": {"key": "val"}, + "createTimestamp": "2025-01-01T00:00:00Z", + } + msg = Message.from_dict(data) + assert Message.from_dict(msg.to_dict()) == msg + + +# ── Enums ───────────────────────────────────────────────────────────────────── + + +class TestEnums: + + def test_message_role_values(self): + """MessageRole enum has the expected members.""" + assert MessageRole.USER == "USER" + assert MessageRole.ASSISTANT == "ASSISTANT" + assert MessageRole.SYSTEM == "SYSTEM" + assert MessageRole.TOOL == "TOOL" + + def test_message_role_is_str(self): + """MessageRole values can be used as strings.""" + assert isinstance(MessageRole.USER, str) + assert MessageRole.USER == "USER" + + +# ── RetentionConfig ───────────────────────────────────────────────────────────── + + +class TestRetentionConfig: + + def test_from_dict_maps_known_fields(self): + """All known retention config fields are mapped correctly.""" + data = { + "id": 1, + "messageDays": 30, + "memoryDays": 90, + "usageLogDays": 180, + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + rc = RetentionConfig.from_dict(data) + + assert rc.id == 1 + assert rc.message_days == 30 + assert rc.memory_days == 90 + assert rc.usage_log_days == 180 + assert rc.create_timestamp == "2025-01-01T00:00:00Z" + assert rc.update_timestamp == "2025-01-02T00:00:00Z" + + def test_from_dict_optional_fields_default_to_none(self): + """Optional fields default to None when absent.""" + data = {"id": 1} + rc = RetentionConfig.from_dict(data) + assert rc.message_days is None + assert rc.memory_days is None + assert rc.usage_log_days is None + assert rc.create_timestamp is None + assert rc.update_timestamp is None + + def test_from_dict_empty_dict_uses_defaults(self): + """An empty dict produces safe defaults.""" + rc = RetentionConfig.from_dict({}) + assert rc.id is None + assert rc.message_days is None + + def test_round_trip(self): + """from_dict(to_dict(rc)) reproduces the original object.""" + data = { + "id": 1, + "messageDays": 30, + "memoryDays": 90, + "usageLogDays": 180, + "createTimestamp": "2025-01-01T00:00:00Z", + "updateTimestamp": "2025-01-02T00:00:00Z", + } + rc = RetentionConfig.from_dict(data) + assert RetentionConfig.from_dict(rc.to_dict()) == rc diff --git a/tests/agent_memory/unit/test_odata.py b/tests/agent_memory/unit/test_odata.py new file mode 100644 index 0000000..deda8f0 --- /dev/null +++ b/tests/agent_memory/unit/test_odata.py @@ -0,0 +1,311 @@ +"""Unit tests for OData query-building utilities.""" + +from sap_cloud_sdk.agent_memory.utils._odata import ( + FilterDefinition, + _escape_odata_string, + build_contains_clauses, + build_list_params, + build_memory_filter, + build_message_filter, + extract_value_and_count, +) + + +# ── _escape_odata_string ────────────────────────────────────────────────────── + + +class TestEscapeOdataString: + + def test_single_quote_is_doubled(self): + """Single quotes are escaped as '' per OData 4.01.""" + assert _escape_odata_string("it's") == "it''s" + + def test_multiple_single_quotes(self): + """All single quotes in the value are escaped.""" + assert _escape_odata_string("a'b'c") == "a''b''c" + + def test_no_quotes_unchanged(self): + """Values without single quotes are returned unchanged.""" + assert _escape_odata_string("john") == "john" + + def test_empty_string(self): + """Empty string is returned unchanged.""" + assert _escape_odata_string("") == "" + + +# ── build_contains_clauses ──────────────────────────────────────────────────── + + +class TestBuildContainsClauses: + + def test_single_clause(self): + """A single clause produces one contains() expression.""" + result = build_contains_clauses([FilterDefinition(target="metadata", contains="john")]) + assert result == ["contains(metadata, 'john')"] + + def test_multiple_clauses(self): + """Multiple clauses produce one expression per clause.""" + result = build_contains_clauses([ + FilterDefinition(target="metadata", contains="john"), + FilterDefinition(target="content", contains="user prefers"), + ]) + assert result == [ + "contains(metadata, 'john')", + "contains(content, 'user prefers')", + ] + + def test_single_quote_in_value_is_escaped(self): + """Single quotes inside the contains value are escaped.""" + result = build_contains_clauses([FilterDefinition(target="content", contains="it's")]) + assert result == ["contains(content, 'it''s')"] + + def test_empty_clause_list_returns_empty_list(self): + """An empty input list returns an empty list.""" + assert build_contains_clauses([]) == [] + + +# ── build_memory_filter ────────────────────────────────────────────────────── + + +class TestBuildMemoryFilter: + + def test_agent_id_only(self): + """Single agent_id produces a simple eq filter.""" + result = build_memory_filter(agent_id="agent-a") + assert result == "agentID eq 'agent-a'" + + def test_invoker_id_only(self): + """Single invoker_id produces a simple eq filter.""" + result = build_memory_filter(invoker_id="user-b") + assert result == "invokerID eq 'user-b'" + + def test_both_filters(self): + """Both filters are combined with 'and'.""" + result = build_memory_filter(agent_id="a", invoker_id="u") + assert result is not None + assert "agentID eq 'a'" in result + assert "invokerID eq 'u'" in result + assert " and " in result + + def test_no_filters_returns_none(self): + """No arguments returns None.""" + result = build_memory_filter() + assert result is None + + def test_filter_clauses_appended(self): + """FilterDefinition objects are appended as contains() expressions.""" + result = build_memory_filter( + agent_id="a", + filter_clauses=[FilterDefinition(target="content", contains="dark mode")], + ) + assert result is not None + assert "agentID eq 'a'" in result + assert "contains(content, 'dark mode')" in result + assert " and " in result + + def test_agent_id_with_single_quote_is_escaped(self): + """Single quotes in agent_id are escaped to prevent malformed OData.""" + result = build_memory_filter(agent_id="bob's-agent") + assert result == "agentID eq 'bob''s-agent'" + + def test_invoker_id_with_single_quote_is_escaped(self): + """Single quotes in invoker_id are escaped.""" + result = build_memory_filter(invoker_id="o'brien") + assert result == "invokerID eq 'o''brien'" + + +# ── build_message_filter ───────────────────────────────────────────────────── + + +class TestBuildMessageFilter: + + def test_all_filters(self): + """All four message filters are combined.""" + result = build_message_filter( + agent_id="a", invoker_id="u", + message_group="g", role="USER", + ) + assert result is not None + assert "agentID eq 'a'" in result + assert "invokerID eq 'u'" in result + assert "messageGroup eq 'g'" in result + assert "role eq 'USER'" in result + + def test_single_filter(self): + """A single filter produces a simple eq expression.""" + result = build_message_filter(role="ASSISTANT") + assert result == "role eq 'ASSISTANT'" + + def test_no_filters_returns_none(self): + """No arguments returns None.""" + result = build_message_filter() + assert result is None + + def test_filter_clauses_appended(self): + """FilterDefinition objects are appended as contains() expressions.""" + result = build_message_filter( + agent_id="a", + filter_clauses=[FilterDefinition(target="content", contains="invoice")], + ) + assert result is not None + assert "agentID eq 'a'" in result + assert "contains(content, 'invoice')" in result + assert " and " in result + + def test_agent_id_with_single_quote_is_escaped(self): + """Single quotes in agent_id are escaped to prevent malformed OData.""" + result = build_message_filter(agent_id="bob's-agent") + assert result == "agentID eq 'bob''s-agent'" + + def test_invoker_id_with_single_quote_is_escaped(self): + """Single quotes in invoker_id are escaped.""" + result = build_message_filter(invoker_id="user-x'y") + assert result == "invokerID eq 'user-x''y'" + + def test_message_group_with_single_quote_is_escaped(self): + """Single quotes in message_group are escaped.""" + result = build_message_filter(message_group="group'1") + assert result == "messageGroup eq 'group''1'" + + def test_role_with_single_quote_is_escaped(self): + """Single quotes in role are escaped.""" + result = build_message_filter(role="US'ER") + assert result == "role eq 'US''ER'" + + +# ── build_list_params ──────────────────────────────────────────────────────── + + +class TestBuildListParams: + + def test_filter_param(self): + """filter_expr is output as $filter.""" + params = build_list_params(filter_expr="agentID eq 'a'") + assert params["$filter"] == "agentID eq 'a'" + + def test_search_param(self): + """search is output as $search.""" + params = build_list_params(search="dark mode") + assert params["$search"] == "dark mode" + + def test_select_param_injects_required_fields(self): + """$select always includes agentID and invokerID.""" + params = build_list_params(select="id,content") + fields = set(params["$select"].split(",")) + assert "id" in fields + assert "content" in fields + assert "agentID" in fields + assert "invokerID" in fields + + def test_select_param_no_duplicates(self): + """$select does not duplicate agentID/invokerID if already present.""" + params = build_list_params(select="agentID,invokerID,content") + fields = params["$select"].split(",") + assert fields.count("agentID") == 1 + assert fields.count("invokerID") == 1 + + def test_top_param(self): + """top is output as $top string.""" + params = build_list_params(top=10) + assert params["$top"] == "10" + + def test_skip_param(self): + """skip is output as $skip string.""" + params = build_list_params(skip=20) + assert params["$skip"] == "20" + + def test_order_by_param(self): + """order_by is output as $orderby.""" + params = build_list_params(order_by="createTimestamp desc") + assert params["$orderby"] == "createTimestamp desc" + + def test_count_true(self): + """count=True outputs $count=true.""" + params = build_list_params(count=True) + assert params["$count"] == "true" + + def test_count_false(self): + """count=False omits $count.""" + params = build_list_params(count=False) + assert "$count" not in params + + def test_empty_params(self): + """No arguments produces an empty dict.""" + params = build_list_params() + assert params == {} + + def test_all_params(self): + """All parameters are set in a single call.""" + params = build_list_params( + filter_expr="agentID eq 'a'", + search="hello", + select="id,content", + top=5, + skip=10, + order_by="createTimestamp desc", + count=True, + ) + assert params["$filter"] == "agentID eq 'a'" + assert params["$search"] == "hello" + assert "id" in params["$select"] + assert params["$top"] == "5" + assert params["$skip"] == "10" + assert params["$orderby"] == "createTimestamp desc" + assert params["$count"] == "true" + + +# ── extract_value_and_count ────────────────────────────────────────────────── + + +class TestExtractValueAndCount: + + def test_data_format(self): + """Parses the 'data' + 'count' response format.""" + response = { + "data": [{"id": "1"}, {"id": "2"}], + "count": 42, + } + items, total = extract_value_and_count(response) + assert len(items) == 2 + assert total == 42 + + def test_odata_value_format(self): + """Parses the OData 'value' + '@odata.count' response format.""" + response = { + "value": [{"id": "1"}], + "@odata.count": 10, + } + items, total = extract_value_and_count(response) + assert len(items) == 1 + assert total == 10 + + def test_odata_count_key(self): + """Parses the OData '@count' response format (alternative to '@odata.count').""" + response = { + "value": [{"id": "1"}], + "@count": 7, + } + items, total = extract_value_and_count(response) + assert len(items) == 1 + assert total == 7 + + def test_value_takes_precedence_over_data(self): + """When both 'value' and 'data' are present, 'value' wins (OData v4 standard).""" + response = { + "data": [{"id": "from-data"}], + "value": [{"id": "from-value"}], + } + items, _ = extract_value_and_count(response) + assert items[0]["id"] == "from-value" + + def test_missing_count_returns_none(self): + """Missing count keys return None.""" + response = {"data": []} + _, total = extract_value_and_count(response) + assert total is None + + def test_empty_response(self): + """Empty dict returns empty list and None count.""" + items, total = extract_value_and_count({}) + assert items == [] + assert total is None diff --git a/tests/core/unit/telemetry/test_module.py b/tests/core/unit/telemetry/test_module.py index 9d44e31..beb79af 100644 --- a/tests/core/unit/telemetry/test_module.py +++ b/tests/core/unit/telemetry/test_module.py @@ -47,12 +47,13 @@ def test_module_in_collection(self): def test_all_modules_present(self): """Test that all expected modules are present.""" all_modules = list(Module) - assert len(all_modules) == 6 + assert len(all_modules) == 7 assert Module.AICORE in all_modules assert Module.AUDITLOG in all_modules assert Module.DESTINATION in all_modules assert Module.OBJECTSTORE in all_modules assert Module.DMS in all_modules + assert Module.AGENT_MEMORY in all_modules def test_module_iteration(self): """Test iterating over Module enum.""" diff --git a/tests/core/unit/telemetry/test_operation.py b/tests/core/unit/telemetry/test_operation.py index fc5f0a5..4a5337d 100644 --- a/tests/core/unit/telemetry/test_operation.py +++ b/tests/core/unit/telemetry/test_operation.py @@ -132,5 +132,6 @@ def test_operation_iteration(self): def test_operation_count(self): """Test that we have the expected number of operations.""" all_operations = list(Operation) - # 3 auditlog + 11 destination + 10 certificate + 10 fragment + 8 objectstore + 2 aicore + 23 dms = 67 - assert len(all_operations) == 67 + # 3 auditlog + 11 destination + 10 certificate + 10 fragment + # + 8 objectstore + 2 aicore + 23 dms + 13 agent_memory = 80 + assert len(all_operations) == 80