diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index 0f6ffb4..6f076ce 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -25,6 +25,8 @@ IntegrationDependency, MCPTool, ) +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache, compute_expires_at +from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError logger = logging.getLogger(__name__) @@ -211,19 +213,24 @@ def _request_token_mtls( credentials: CustomerCredentials, grant_type: str, timeout: float, + config: ClientConfig, app_tid: str | None = None, extra_data: dict | None = None, -) -> str: +) -> tuple[str, float]: """Make mTLS token request to IAS. Args: credentials: Customer credentials with certificate and private key. grant_type: OAuth2 grant type. + timeout: HTTP timeout in seconds. + config: Client configuration (used to compute cache expiry). app_tid: BTP Application Tenant ID of subscriber (optional). extra_data: Additional form data for the token request. Returns: - Access token string. + Tuple of (access_token, expires_at) where expires_at is a + time.monotonic() value indicating when the cached token should + be refreshed (already includes the configured buffer). Raises: AgentGatewaySDKError: If token request fails. @@ -282,8 +289,10 @@ def _request_token_mtls( f"Token response missing 'access_token'. Keys: {list(token_data.keys())}" ) + expires_at = compute_expires_at(token_data, config) + logger.debug("Token acquired successfully (length: %d)", len(access_token)) - return access_token + return access_token, expires_at except httpx.RequestError as e: raise AgentGatewaySDKError(f"Token request failed: {e}") @@ -292,61 +301,87 @@ def _request_token_mtls( def get_system_token_mtls( credentials: CustomerCredentials, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, ) -> str: """Get system-scoped token using mTLS client credentials flow. - Used for tool discovery where user identity is not needed. + Used for tool discovery where user identity is not needed. Returns + a cached token if still valid; otherwise acquires a fresh one. Args: credentials: Customer credentials. timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache to consult and update. app_tid: BTP Application Tenant ID of subscriber (optional). Returns: System-scoped access token. """ + cached = cache.get_system_token(app_tid) + if cached: + logger.debug("Using cached system token (app_tid=%s)", app_tid) + return cached + logger.info("Acquiring system token via mTLS client credentials") - return _request_token_mtls( + token, expires_at = _request_token_mtls( credentials, grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS, timeout=timeout, + config=config, app_tid=app_tid, extra_data={"response_type": "token"}, ) + cache.set_system_token(token, expires_at, app_tid) + return token def exchange_user_token( credentials: CustomerCredentials, user_token: str, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, ) -> str: """Exchange user token for AGW-scoped token using jwt-bearer grant. Used for tool invocation where user identity must be preserved - for principal propagation. + for principal propagation. Returns a cached exchanged token if + still valid; otherwise acquires a fresh one. Args: credentials: Customer credentials. user_token: User's JWT token to exchange. timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache to consult and update. app_tid: BTP Application Tenant ID of subscriber (optional). Returns: AGW-scoped access token with user identity. """ + cached = cache.get_user_token(user_token, app_tid) + if cached: + logger.debug("Using cached user token (app_tid=%s)", app_tid) + return cached + logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant") - return _request_token_mtls( + token, expires_at = _request_token_mtls( credentials, grant_type=_GRANT_TYPE_JWT_BEARER, timeout=timeout, + config=config, app_tid=app_tid, extra_data={ "assertion": user_token, "token_format": "jwt", }, ) + cache.set_user_token(user_token, token, expires_at, app_tid) + return token def _build_mcp_url(gateway_url: str, ord_id: str, gt_id: str) -> str: @@ -433,6 +468,8 @@ async def _list_server_tools( async def get_mcp_tools_customer( credentials: CustomerCredentials, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, ) -> list[MCPTool]: """List all MCP tools from servers defined in credentials. @@ -442,6 +479,9 @@ async def get_mcp_tools_customer( Args: credentials: Customer credentials with integrationDependencies. + timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache shared across calls. app_tid: BTP Application Tenant ID of subscriber (optional). Returns: @@ -462,7 +502,7 @@ async def get_mcp_tools_customer( # Get system token for discovery loop = asyncio.get_running_loop() system_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, app_tid + None, get_system_token_mtls, credentials, timeout, config, cache, app_tid ) tools: list[MCPTool] = [] @@ -480,7 +520,42 @@ async def get_mcp_tools_customer( server_tools = await _list_server_tools(url, system_token, dep, timeout) tools.extend(server_tools) logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id) - except Exception: + except Exception as exc: + unwrapped = _unwrap_exception_group(exc) + if _is_unauthorized(unwrapped): + logger.info( + "401 from %s — invalidating cached system token and retrying", + dep.ord_id, + ) + cache.invalidate_system_token(app_tid) + try: + fresh_token = await loop.run_in_executor( + None, + get_system_token_mtls, + credentials, + timeout, + config, + cache, + app_tid, + ) + server_tools = await _list_server_tools( + url, fresh_token, dep, timeout + ) + tools.extend(server_tools) + # Replace stale token for remaining iterations + system_token = fresh_token + logger.debug( + "Loaded %d tool(s) from %s after retry", + len(server_tools), + dep.ord_id, + ) + continue + except Exception: + logger.exception( + "Failed to load tools from %s after retry — skipping", + dep.ord_id, + ) + continue logger.exception("Failed to load tools from %s — skipping", dep.ord_id) logger.info( @@ -494,6 +569,8 @@ async def call_mcp_tool_customer( tool: MCPTool, user_token: str | None, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, **kwargs, ) -> str: @@ -502,11 +579,16 @@ async def call_mcp_tool_customer( If user_token is provided, exchanges it for an AGW-scoped token to preserve user identity for principal propagation. Otherwise, falls back to system token. + On a 401 from the MCP server, drops the cached token and retries once. + Args: credentials: Customer credentials. tool: MCPTool to invoke. user_token: User's JWT token for principal propagation (optional). If None, system token is used instead (no principal propagation). + timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache shared across calls. app_tid: BTP Application Tenant ID of subscriber (optional). **kwargs: Tool input parameters. @@ -517,12 +599,18 @@ async def call_mcp_tool_customer( loop = asyncio.get_running_loop() - if user_token: - # Exchange user token for AGW-scoped token (with principal propagation) - agw_token = await loop.run_in_executor( - None, exchange_user_token, credentials, user_token, timeout, app_tid - ) - else: + async def _acquire_token() -> str: + if user_token: + return await loop.run_in_executor( + None, + exchange_user_token, + credentials, + user_token, + timeout, + config, + cache, + app_tid, + ) # TODO: IBD workaround - use system token when user_token is not available. # This bypasses principal propagation. Remove this fallback once IBD # supports proper user token flow. @@ -530,13 +618,55 @@ async def call_mcp_tool_customer( "No user_token provided - using system token for tool invocation. " "Principal propagation will NOT work." ) - agw_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, app_tid + return await loop.run_in_executor( + None, get_system_token_mtls, credentials, timeout, config, cache, app_tid ) + def _invalidate_token() -> None: + if user_token: + cache.invalidate_user_token(user_token, app_tid) + else: + cache.invalidate_system_token(app_tid) + + last_exc: Exception | None = None + for attempt in (1, 2): + agw_token = await _acquire_token() + try: + return await _invoke_tool(tool, agw_token, timeout, **kwargs) + except Exception as exc: + unwrapped = _unwrap_exception_group(exc) + if _is_unauthorized(unwrapped) and attempt == 1: + logger.info( + "401 from MCP server for tool '%s' — invalidating cached token and retrying", + tool.name, + ) + _invalidate_token() + last_exc = exc + continue + raise + + # Defensive — should not be reachable; second attempt either returns or raises. + raise AgentGatewaySDKError( + f"Tool invocation for '{tool.name}' failed after 401 retry: {last_exc}" + ) + + +async def _invoke_tool( + tool: MCPTool, + auth_token: str, + timeout: float, + **kwargs, +) -> str: + """Open an MCP session to `tool.url` and invoke `tool.name` with `kwargs`. + + Returns the first content block's text, or empty string when content is + empty. Raises whatever the MCP transport / session raises (notably + `httpx.HTTPStatusError` on 401, which the caller uses to drive cache + invalidation and retry). + """ async with httpx.AsyncClient( headers={ - "Authorization": f"Bearer {agw_token}", + "Authorization": f"Bearer {auth_token}", "x-correlation-id": str(uuid.uuid4()), }, timeout=timeout, @@ -556,3 +686,17 @@ async def call_mcp_tool_customer( first = result.content[0] return str(getattr(first, "text", "")) + + +def _unwrap_exception_group(exc: BaseException) -> BaseException: + """Unwrap nested ExceptionGroups to find the underlying cause.""" + while isinstance(exc, BaseExceptionGroup) and exc.exceptions: + exc = exc.exceptions[0] + return exc + + +def _is_unauthorized(exc: BaseException) -> bool: + """Detect a 401 response from the MCP server (httpx-based).""" + if isinstance(exc, httpx.HTTPStatusError): + return exc.response is not None and exc.response.status_code == 401 + return False diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py new file mode 100644 index 0000000..324ff3f --- /dev/null +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -0,0 +1,188 @@ +"""Token cache for Agent Gateway customer flow. + +Caches IAS tokens (system + user-exchanged) per client to avoid redundant +mTLS token requests during agentic loops. LoB flow uses BTP Destination +Service which has its own caching, so this module only serves the customer +flow. + +Keying: +- System tokens are keyed by `app_tid` (or "_default" when unset). +- User tokens are keyed by `sha256(user_jwt + "|" + (app_tid or ""))[:16]`. + +The `app_tid` component is required because `_request_token_mtls` includes +it in the form payload, producing a tenant-scoped token. Mixing tokens +across tenants would break principal propagation. + +Thread safety: +Token fetches run in the default `ThreadPoolExecutor` via +`loop.run_in_executor`. CPython GIL makes individual dict / OrderedDict +operations atomic, but compound check-then-set is not. Two concurrent +coroutines for the same key may both miss and both fetch; the race +produces redundant token requests, not corruption. +""" + +import base64 +import hashlib +import json +import logging +import time +from collections import OrderedDict +from dataclasses import dataclass + +from sap_cloud_sdk.agentgateway.config import ClientConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class _CachedToken: + """A cached token with monotonic expiry.""" + + token: str + expires_at: float # time.monotonic() value + + def is_valid(self) -> bool: + """Return True if the token has not yet reached its monotonic expiry.""" + return time.monotonic() < self.expires_at + + +def _parse_jwt_exp(jwt: str) -> int | None: + """Extract `exp` claim (seconds since epoch) from a JWT without verification. + + Returns None if the JWT is malformed or has no `exp` claim. The result + is used only as a hint for cache TTL — never for security decisions. + """ + try: + parts = jwt.split(".") + if len(parts) < 2: + return None + payload_b64 = parts[1] + payload_b64 += "=" * (-len(payload_b64) % 4) + claims = json.loads(base64.urlsafe_b64decode(payload_b64)) + exp = claims.get("exp") + return int(exp) if exp is not None else None + except (ValueError, KeyError, TypeError, json.JSONDecodeError): + return None + + +def compute_expires_at(token_data: dict, config: ClientConfig) -> float: + """Resolve the cache expiry timestamp (monotonic) for a token response. + + Resolution order: + 1. `expires_in` from the response, minus the buffer. + 2. `exp` claim from `id_token` (translated from wall clock to monotonic), + minus the buffer. + 3. Config-provided fallback TTL. + """ + now_mono = time.monotonic() + buffer = config.token_expiry_buffer_seconds + + expires_in = token_data.get("expires_in") + if expires_in is not None: + try: + return now_mono + int(expires_in) - buffer + except (ValueError, TypeError): + pass + + id_token = token_data.get("id_token") + if id_token: + exp = _parse_jwt_exp(id_token) + if exp is not None: + remaining = exp - time.time() + if remaining > buffer: + return now_mono + remaining - buffer + + return now_mono + config.fallback_token_ttl_seconds + + +class _TokenCache: + """Per-client token cache with TTL and LRU eviction. + + Both system and user tokens use OrderedDict for LRU ordering. Keys + include `app_tid` so tenant-scoped tokens never leak across tenants. + """ + + _SYSTEM_DEFAULT_KEY = "_default" + + def __init__(self, config: ClientConfig): + """Initialize empty caches bounded by sizes from `config`.""" + self._config = config + self._system_tokens: OrderedDict[str, _CachedToken] = OrderedDict() + self._user_tokens: OrderedDict[str, _CachedToken] = OrderedDict() + + # --- System Token --- + + def get_system_token(self, app_tid: str | None) -> str | None: + """Return a valid cached system token for `app_tid`, or None on miss/expiry.""" + key = app_tid or self._SYSTEM_DEFAULT_KEY + cached = self._system_tokens.get(key) + if cached and cached.is_valid(): + self._system_tokens.move_to_end(key) + return cached.token + if cached: + del self._system_tokens[key] + return None + + def set_system_token( + self, token: str, expires_at: float, app_tid: str | None + ) -> None: + """Cache a system token under `app_tid`; evict LRU once size exceeds limit.""" + key = app_tid or self._SYSTEM_DEFAULT_KEY + self._system_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + self._system_tokens.move_to_end(key) + while len(self._system_tokens) > self._config.max_system_token_cache_size: + evicted, _ = self._system_tokens.popitem(last=False) + logger.debug("System token cache full — evicted '%s'", evicted) + + def invalidate_system_token(self, app_tid: str | None) -> None: + """Drop the cached system token for `app_tid` (no-op if absent).""" + key = app_tid or self._SYSTEM_DEFAULT_KEY + if self._system_tokens.pop(key, None): + logger.debug("Invalidated system token (app_tid=%s)", app_tid) + + # --- User Tokens --- + + def get_user_token(self, user_jwt: str, app_tid: str | None) -> str | None: + """Return a valid cached exchanged token for `(user_jwt, app_tid)`, or None.""" + key = self._hash_key(user_jwt, app_tid) + cached = self._user_tokens.get(key) + if cached and cached.is_valid(): + self._user_tokens.move_to_end(key) + return cached.token + if cached: + del self._user_tokens[key] + return None + + def set_user_token( + self, + user_jwt: str, + token: str, + expires_at: float, + app_tid: str | None, + ) -> None: + """Cache an exchanged user token; evict LRU once size exceeds limit.""" + key = self._hash_key(user_jwt, app_tid) + self._user_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + self._user_tokens.move_to_end(key) + while len(self._user_tokens) > self._config.max_user_token_cache_size: + evicted, _ = self._user_tokens.popitem(last=False) + logger.debug("User token cache full — evicted '%s'", evicted) + + def invalidate_user_token(self, user_jwt: str, app_tid: str | None) -> None: + """Drop the cached user token for `(user_jwt, app_tid)` (no-op if absent).""" + key = self._hash_key(user_jwt, app_tid) + if self._user_tokens.pop(key, None): + logger.debug("Invalidated user token (app_tid=%s)", app_tid) + + # --- Maintenance --- + + def clear(self) -> None: + """Drop all cached tokens. Forces a fresh fetch on next access.""" + self._system_tokens.clear() + self._user_tokens.clear() + + @staticmethod + def _hash_key(user_jwt: str, app_tid: str | None) -> str: + """Derive a short, stable cache key from `(user_jwt, app_tid)` via sha256.""" + material = f"{user_jwt}|{app_tid or ''}" + return hashlib.sha256(material.encode()).hexdigest()[:16] diff --git a/src/sap_cloud_sdk/agentgateway/agw_client.py b/src/sap_cloud_sdk/agentgateway/agw_client.py index a601d88..a9ada19 100644 --- a/src/sap_cloud_sdk/agentgateway/agw_client.py +++ b/src/sap_cloud_sdk/agentgateway/agw_client.py @@ -11,6 +11,7 @@ from typing import Callable from sap_cloud_sdk.agentgateway._models import MCPTool +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway._customer import ( detect_customer_agent_credentials, @@ -85,6 +86,17 @@ def __init__( """ self._tenant_subdomain = tenant_subdomain self._config = config or ClientConfig() + self._token_cache = _TokenCache(self._config) + + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_CLEAR_TOKEN_CACHE) + def clear_token_cache(self) -> None: + """Drop all cached tokens. Forces a fresh token fetch on the next call. + + Useful when external state (revoked credentials, tenant change) makes + cached tokens unsafe to reuse, or for testing. No-op for LoB flow, + which delegates caching to BTP Destination Service. + """ + self._token_cache.clear() @staticmethod def _resolve_value( @@ -158,7 +170,11 @@ async def list_mcp_tools( ) credentials = load_customer_credentials(credentials_path) return await get_mcp_tools_customer( - credentials, self._config.timeout, app_tid + credentials, + self._config.timeout, + self._config, + self._token_cache, + app_tid, ) # LoB flow - requires tenant_subdomain @@ -251,6 +267,8 @@ async def call_mcp_tool( tool, resolved_user_token, self._config.timeout, + self._config, + self._token_cache, app_tid, **kwargs, ) diff --git a/src/sap_cloud_sdk/agentgateway/config.py b/src/sap_cloud_sdk/agentgateway/config.py index 427f96b..b44af1f 100644 --- a/src/sap_cloud_sdk/agentgateway/config.py +++ b/src/sap_cloud_sdk/agentgateway/config.py @@ -3,6 +3,10 @@ from dataclasses import dataclass DEFAULT_TIMEOUT_SECONDS = 60.0 +DEFAULT_TOKEN_EXPIRY_BUFFER_SECONDS = 60 +DEFAULT_MAX_USER_TOKEN_CACHE_SIZE = 10 +DEFAULT_MAX_SYSTEM_TOKEN_CACHE_SIZE = 10 +DEFAULT_FALLBACK_TOKEN_TTL_SECONDS = 300 @dataclass @@ -12,6 +16,20 @@ class ClientConfig: Attributes: timeout: HTTP timeout in seconds for token requests and MCP server calls. Defaults to 60 seconds. + token_expiry_buffer_seconds: Refresh tokens this many seconds before + their reported expiry. Defaults to 60 seconds. + max_user_token_cache_size: Maximum number of user tokens cached + per client. LRU eviction once exceeded. Defaults to 10. + max_system_token_cache_size: Maximum number of system tokens cached + per client (one per app_tid). LRU eviction once exceeded. + Defaults to 10. + fallback_token_ttl_seconds: TTL applied when neither `expires_in` + nor a parseable `id_token` exp claim is available in the token + response. Defaults to 300 seconds. """ timeout: float = DEFAULT_TIMEOUT_SECONDS + token_expiry_buffer_seconds: int = DEFAULT_TOKEN_EXPIRY_BUFFER_SECONDS + max_user_token_cache_size: int = DEFAULT_MAX_USER_TOKEN_CACHE_SIZE + max_system_token_cache_size: int = DEFAULT_MAX_SYSTEM_TOKEN_CACHE_SIZE + fallback_token_ttl_seconds: int = DEFAULT_FALLBACK_TOKEN_TTL_SECONDS diff --git a/src/sap_cloud_sdk/core/telemetry/operation.py b/src/sap_cloud_sdk/core/telemetry/operation.py index 8619145..b114523 100644 --- a/src/sap_cloud_sdk/core/telemetry/operation.py +++ b/src/sap_cloud_sdk/core/telemetry/operation.py @@ -107,6 +107,7 @@ class Operation(str, Enum): # Agent Gateway Operations AGENTGATEWAY_LIST_MCP_TOOLS = "list_mcp_tools" AGENTGATEWAY_CALL_MCP_TOOL = "call_mcp_tool" + AGENTGATEWAY_CLEAR_TOKEN_CACHE = "clear_token_cache" # Agent Memory Operations AGENT_MEMORY_ADD_MEMORY = "add_memory" diff --git a/tests/agentgateway/unit/test_agw_client.py b/tests/agentgateway/unit/test_agw_client.py index c60e8bb..f54c242 100644 --- a/tests/agentgateway/unit/test_agw_client.py +++ b/tests/agentgateway/unit/test_agw_client.py @@ -1,7 +1,9 @@ """Unit tests for Agent Gateway client.""" -from unittest.mock import patch, AsyncMock +import time +from unittest.mock import patch, AsyncMock, MagicMock +import httpx import pytest from sap_cloud_sdk.agentgateway import ( @@ -10,6 +12,10 @@ MCPTool, AgentGatewaySDKError, ) +from sap_cloud_sdk.agentgateway._models import ( + CustomerCredentials, + IntegrationDependency, +) # ============================================================ @@ -411,3 +417,273 @@ async def test_returns_result_from_lob_flow(self, mock_tool): ) assert result == "Success: Order created" + + +# ============================================================ +# Test: Token cache behavior through the public API +# ============================================================ + + +def _customer_credentials() -> CustomerCredentials: + """Build a minimal CustomerCredentials fixture for cache-behavior tests.""" + return CustomerCredentials( + token_service_url="https://ias.example.com/oauth2/token", + client_id="test-client", + certificate="cert", + private_key="key", + gateway_url="https://agw.example.com", + integration_dependencies=[ + IntegrationDependency( + ord_id="sap.test:apiResource:demo:v1", + global_tenant_id="250695", + ), + ], + ) + + +def _build_streaming_mocks( + initialize_side_effect=None, + call_tool_side_effect=None, + list_tools_side_effect=None, +): + """Build the chain of mocks needed to drive customer flow MCP calls.""" + http_client = AsyncMock() + http_client.__aenter__ = AsyncMock(return_value=http_client) + http_client.__aexit__ = AsyncMock(return_value=None) + + stream_ctx = AsyncMock() + stream_ctx.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock(), None)) + stream_ctx.__aexit__ = AsyncMock(return_value=None) + + session = AsyncMock() + if initialize_side_effect is not None: + session.initialize = AsyncMock(side_effect=initialize_side_effect) + else: + init_result = MagicMock() + init_result.serverInfo.name = "demo-server" + session.initialize = AsyncMock(return_value=init_result) + + if list_tools_side_effect is not None: + session.list_tools = AsyncMock(side_effect=list_tools_side_effect) + else: + list_result = MagicMock() + list_result.tools = [] + session.list_tools = AsyncMock(return_value=list_result) + + if call_tool_side_effect is not None: + session.call_tool = AsyncMock(side_effect=call_tool_side_effect) + else: + call_result = MagicMock() + content = MagicMock() + content.text = "ok" + call_result.content = [content] + session.call_tool = AsyncMock(return_value=call_result) + + session_ctx = AsyncMock() + session_ctx.__aenter__ = AsyncMock(return_value=session) + session_ctx.__aexit__ = AsyncMock(return_value=None) + + return http_client, stream_ctx, session_ctx + + +def _make_401() -> httpx.HTTPStatusError: + """Construct an httpx 401 HTTPStatusError for simulating MCP auth failures.""" + request = httpx.Request("POST", "https://example.com") + response = httpx.Response(401, request=request) + return httpx.HTTPStatusError("Unauthorized", request=request, response=response) + + +def _patch_customer_flow(token_request_side_effect): + """Patch detection/loading + IAS request + MCP transport for customer flow. + + Returns the http/stream/session mocks plus the IAS request mock so callers + can assert on call counts. + """ + http_client, stream_ctx, session_ctx = _build_streaming_mocks() + + request_mock = MagicMock(side_effect=token_request_side_effect) + + patches = [ + patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + return_value=_customer_credentials(), + ), + patch( + "sap_cloud_sdk.agentgateway._customer._request_token_mtls", + request_mock, + ), + patch("httpx.AsyncClient", return_value=http_client), + patch( + "sap_cloud_sdk.agentgateway._customer.streamable_http_client", + return_value=stream_ctx, + ), + patch( + "sap_cloud_sdk.agentgateway._customer.ClientSession", + return_value=session_ctx, + ), + ] + return patches, request_mock, session_ctx + + +class TestTokenCacheBehavior: + """Cache behavior verified through AgentGatewayClient public API.""" + + @pytest.mark.asyncio + async def test_list_mcp_tools_twice_hits_ias_once(self, mock_tool): + """Two list_mcp_tools calls share one cached system token.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=lambda *a, **kw: ( + "system-token", + time.monotonic() + 600, + ) + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.list_mcp_tools() + await agw_client.list_mcp_tools() + + assert request_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_call_mcp_tool_twice_same_user_token_hits_ias_once(self, mock_tool): + """Two call_mcp_tool calls with same user_token reuse exchanged token.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=lambda *a, **kw: ( + "exchanged-token", + time.monotonic() + 600, + ) + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-A") + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-A") + + assert request_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_different_user_tokens_isolated(self, mock_tool): + """Different user_tokens trigger separate exchanges.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=[ + ("tok-A", time.monotonic() + 600), + ("tok-B", time.monotonic() + 600), + ] + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-A") + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-B") + + assert request_mock.call_count == 2 + + @pytest.mark.asyncio + async def test_app_tid_isolation(self, mock_tool): + """Same user_token across different app_tid values stays isolated.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=[ + ("tok-tenant-a", time.monotonic() + 600), + ("tok-tenant-b", time.monotonic() + 600), + ] + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool( + tool=mock_tool, user_token="user-jwt", app_tid="tenant-a" + ) + await agw_client.call_mcp_tool( + tool=mock_tool, user_token="user-jwt", app_tid="tenant-b" + ) + + assert request_mock.call_count == 2 + + @pytest.mark.asyncio + async def test_clear_token_cache_forces_refetch(self, mock_tool): + """clear_token_cache drops cached tokens, next call refetches.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=lambda *a, **kw: ( + "any-token", + time.monotonic() + 600, + ) + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt") + agw_client.clear_token_cache() + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt") + + assert request_mock.call_count == 2 + + @pytest.mark.asyncio + async def test_401_invalidates_cache_and_retries(self, mock_tool): + """A 401 from the MCP server drops the cached token and retries once.""" + http_client, stream_ctx, _ = _build_streaming_mocks() + + # First call_tool raises 401, second returns success + success = MagicMock() + content = MagicMock() + content.text = "ok-after-retry" + success.content = [content] + + session = AsyncMock() + init_result = MagicMock() + init_result.serverInfo.name = "demo-server" + session.initialize = AsyncMock(return_value=init_result) + session.call_tool = AsyncMock(side_effect=[_make_401(), success]) + session_ctx = AsyncMock() + session_ctx.__aenter__ = AsyncMock(return_value=session) + session_ctx.__aexit__ = AsyncMock(return_value=None) + + request_mock = MagicMock( + side_effect=[ + ("stale-token", time.monotonic() + 600), + ("fresh-token", time.monotonic() + 600), + ] + ) + + with ( + patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + return_value=_customer_credentials(), + ), + patch( + "sap_cloud_sdk.agentgateway._customer._request_token_mtls", + request_mock, + ), + patch("httpx.AsyncClient", return_value=http_client), + patch( + "sap_cloud_sdk.agentgateway._customer.streamable_http_client", + return_value=stream_ctx, + ), + patch( + "sap_cloud_sdk.agentgateway._customer.ClientSession", + return_value=session_ctx, + ), + ): + agw_client = create_client() + + result = await agw_client.call_mcp_tool( + tool=mock_tool, user_token="user-jwt" + ) + + assert result == "ok-after-retry" + # Stale exchange + fresh exchange after invalidation + assert request_mock.call_count == 2 + diff --git a/tests/agentgateway/unit/test_customer.py b/tests/agentgateway/unit/test_customer.py index 4ed170b..57a3cfa 100644 --- a/tests/agentgateway/unit/test_customer.py +++ b/tests/agentgateway/unit/test_customer.py @@ -16,6 +16,8 @@ _CREDENTIALS_PATH_ENV, _CREDENTIALS_DEFAULT_PATH, ) +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache +from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway._models import ( CustomerCredentials, IntegrationDependency, @@ -305,7 +307,9 @@ def test_requests_client_credentials_token(self, credentials): mock_client.post.return_value = mock_response mock_client_class.return_value = mock_client - result = get_system_token_mtls(credentials, timeout=60.0) + result = get_system_token_mtls( + credentials, timeout=60.0, config=ClientConfig(), cache=_TokenCache(ClientConfig()) + ) assert result == "system-token-123" mock_client.post.assert_called_once() @@ -332,7 +336,12 @@ def test_raises_on_failed_request(self, credentials): mock_client_class.return_value = mock_client with pytest.raises(AgentGatewaySDKError, match="Token request failed"): - get_system_token_mtls(credentials, timeout=60.0) + get_system_token_mtls( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) # ============================================================ @@ -374,7 +383,13 @@ def test_exchanges_user_token_with_jwt_bearer(self, credentials): mock_client.post.return_value = mock_response mock_client_class.return_value = mock_client - result = exchange_user_token(credentials, "user-jwt-token", timeout=60.0) + result = exchange_user_token( + credentials, + "user-jwt-token", + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) assert result == "exchanged-token-123" call_args = mock_client.post.call_args @@ -403,7 +418,12 @@ def test_passes_app_tid_when_provided(self, credentials): mock_client_class.return_value = mock_client result = exchange_user_token( - credentials, "user-jwt", timeout=60.0, app_tid="test-tid" + credentials, + "user-jwt", + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + app_tid="test-tid", ) assert result == "token-with-tid" @@ -451,7 +471,12 @@ async def test_raises_when_empty_dependencies(self): with pytest.raises( AgentGatewaySDKError, match="integrationDependencies is empty" ): - await get_mcp_tools_customer(credentials, timeout=60.0) + await get_mcp_tools_customer( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) @pytest.mark.asyncio async def test_discovers_tools_from_credentials(self, credentials): @@ -477,7 +502,12 @@ async def test_discovers_tools_from_credentials(self, credentials): return_value=mock_tools, ) as mock_list, ): - result = await get_mcp_tools_customer(credentials, timeout=60.0) + result = await get_mcp_tools_customer( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) assert len(result) == 1 assert result[0].name == "list_cost_centers" @@ -525,7 +555,12 @@ async def mock_list_tools(*args, **kwargs): side_effect=mock_list_tools, ), ): - result = await get_mcp_tools_customer(credentials, timeout=60.0) + result = await get_mcp_tools_customer( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) # Should still return tools from server2 assert len(result) == 1 @@ -615,11 +650,21 @@ async def test_exchanges_user_token_before_call(self, credentials, mock_tool): mock_session_class.return_value = mock_session_ctx result = await call_mcp_tool_customer( - credentials, mock_tool, "user-jwt", 60.0, order_id="12345" + credentials, + mock_tool, + "user-jwt", + 60.0, + ClientConfig(), + _TokenCache(ClientConfig()), + order_id="12345", ) assert result == "Order created successfully" - mock_exchange.assert_called_once_with(credentials, "user-jwt", 60.0, None) + mock_exchange.assert_called_once() + args, _ = mock_exchange.call_args + assert args[0] is credentials + assert args[1] == "user-jwt" + assert args[2] == 60.0 @pytest.mark.asyncio async def test_uses_system_token_when_user_token_not_provided( @@ -671,10 +716,20 @@ async def test_uses_system_token_when_user_token_not_provided( # Call without user_token (None) result = await call_mcp_tool_customer( - credentials, mock_tool, None, 60.0, order_id="12345" + credentials, + mock_tool, + None, + 60.0, + ClientConfig(), + _TokenCache(ClientConfig()), + order_id="12345", ) assert result == "Result with system token" # Should use system token, not exchange - mock_system_token.assert_called_once_with(credentials, 60.0, None) + mock_system_token.assert_called_once() + args, _ = mock_system_token.call_args + assert args[0] is credentials + assert args[1] == 60.0 mock_exchange.assert_not_called() + diff --git a/tests/agentgateway/unit/test_token_cache.py b/tests/agentgateway/unit/test_token_cache.py new file mode 100644 index 0000000..056e196 --- /dev/null +++ b/tests/agentgateway/unit/test_token_cache.py @@ -0,0 +1,114 @@ +"""Unit tests for token cache helpers with non-trivial logic. + +Cache class behavior is tested through AgentGatewayClient (test_agw_client.py) +to keep coverage focused on observable functionality. Only `_parse_jwt_exp` +and `compute_expires_at` are exercised here directly because they contain +parsing/branching logic that is hard to drive through the public API. +""" + +import base64 +import json +import time + +from sap_cloud_sdk.agentgateway._token_cache import ( + _parse_jwt_exp, + compute_expires_at, +) +from sap_cloud_sdk.agentgateway.config import ClientConfig + + +def _make_jwt(claims: dict) -> str: + """Build a non-signed JWT for testing (header.payload.signature).""" + header = base64.urlsafe_b64encode(json.dumps({"alg": "none"}).encode()).rstrip(b"=") + payload = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=") + return f"{header.decode()}.{payload.decode()}.signature" + + +class TestParseJwtExp: + """Tests for the unverified JWT `exp` claim parser.""" + + def test_extracts_exp(self): + """Extract `exp` claim from a well-formed JWT payload.""" + jwt = _make_jwt({"exp": 1700000000, "iat": 1699996400}) + assert _parse_jwt_exp(jwt) == 1700000000 + + def test_returns_none_when_exp_missing(self): + """Return None when payload has no `exp` claim.""" + jwt = _make_jwt({"iat": 1699996400}) + assert _parse_jwt_exp(jwt) is None + + def test_returns_none_for_malformed_jwt(self): + """Return None for strings that are not three-part JWTs.""" + assert _parse_jwt_exp("not-a-jwt") is None + assert _parse_jwt_exp("") is None + assert _parse_jwt_exp("only.two") is None + + def test_returns_none_for_garbage_payload(self): + """Return None when the payload segment is not valid base64/JSON.""" + assert _parse_jwt_exp("aaa.@@not-base64@@.bbb") is None + + +class TestComputeExpiresAt: + """Tests for cache expiry resolution from token responses.""" + + def test_uses_expires_in_when_present(self): + """Prefer `expires_in` from the response and subtract the buffer.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": 3600}, cfg) + assert before + 3540 - 1 <= result <= before + 3540 + 1 + + def test_expires_in_equal_to_buffer_expires_immediately(self): + """Token whose `expires_in` equals the buffer is treated as already expiring now.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": 60}, cfg) + after = time.monotonic() + assert before - 1 <= result <= after + 1 + + def test_expires_in_below_buffer_is_already_stale(self): + """Token whose `expires_in` is below the buffer resolves to a past timestamp.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": 30}, cfg) + assert before - 31 <= result <= before - 29 + + def test_falls_back_to_id_token_exp(self): + """Fall back to the `exp` claim of `id_token` when `expires_in` is absent.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + future_exp = int(time.time()) + 600 + jwt = _make_jwt({"exp": future_exp}) + before = time.monotonic() + result = compute_expires_at({"id_token": jwt}, cfg) + assert before + 540 - 5 <= result <= before + 540 + 5 + + def test_uses_fallback_when_no_expiry_info(self): + """Use config fallback TTL when neither `expires_in` nor `id_token` is present.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"access_token": "opaque"}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1 + + def test_uses_fallback_when_id_token_malformed(self): + """Use fallback TTL when the `id_token` cannot be parsed.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"id_token": "garbage"}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1 + + def test_uses_fallback_when_id_token_exp_within_buffer(self): + """Skip the `id_token` path when remaining lifetime is below the buffer.""" + # If remaining time is below the buffer, the id_token path is skipped. + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + soon_exp = int(time.time()) + 30 + jwt = _make_jwt({"exp": soon_exp}) + before = time.monotonic() + result = compute_expires_at({"id_token": jwt}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1 + + def test_handles_invalid_expires_in_value(self): + """Use fallback TTL when `expires_in` is not coercible to int.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": "not-a-number"}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1