Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 162 additions & 18 deletions src/sap_cloud_sdk/agentgateway/_customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
IntegrationDependency,
MCPTool,
)
from sap_cloud_sdk.agentgateway._token_cache import _TokenCache, compute_expires_at
from sap_cloud_sdk.agentgateway.config import ClientConfig
from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -211,19 +213,24 @@ def _request_token_mtls(
credentials: CustomerCredentials,
grant_type: str,
timeout: float,
config: ClientConfig,
app_tid: str | None = None,
extra_data: dict | None = None,
) -> str:
) -> tuple[str, float]:
"""Make mTLS token request to IAS.

Args:
credentials: Customer credentials with certificate and private key.
grant_type: OAuth2 grant type.
timeout: HTTP timeout in seconds.
config: Client configuration (used to compute cache expiry).
app_tid: BTP Application Tenant ID of subscriber (optional).
extra_data: Additional form data for the token request.

Returns:
Access token string.
Tuple of (access_token, expires_at) where expires_at is a
time.monotonic() value indicating when the cached token should
be refreshed (already includes the configured buffer).

Raises:
AgentGatewaySDKError: If token request fails.
Expand Down Expand Up @@ -282,8 +289,10 @@ def _request_token_mtls(
f"Token response missing 'access_token'. Keys: {list(token_data.keys())}"
)

expires_at = compute_expires_at(token_data, config)

logger.debug("Token acquired successfully (length: %d)", len(access_token))
return access_token
return access_token, expires_at

except httpx.RequestError as e:
raise AgentGatewaySDKError(f"Token request failed: {e}")
Expand All @@ -292,61 +301,87 @@ def _request_token_mtls(
def get_system_token_mtls(
    credentials: CustomerCredentials,
    timeout: float,
    config: ClientConfig,
    cache: _TokenCache,
    app_tid: str | None = None,
) -> str:
    """Get system-scoped token using mTLS client credentials flow.

    Used for tool discovery where user identity is not needed. Returns
    a cached token if the cache yields one for ``app_tid``; otherwise
    acquires a fresh one and stores it in the cache.

    Args:
        credentials: Customer credentials.
        timeout: HTTP timeout in seconds.
        config: Client configuration (forwarded to the token request,
            which uses it to compute the cache expiry).
        cache: Token cache to consult and update.
        app_tid: BTP Application Tenant ID of subscriber (optional).
            Also used as the cache lookup argument.

    Returns:
        System-scoped access token.

    Raises:
        AgentGatewaySDKError: Propagated from the token request when
            it fails.
    """
    # Fast path: reuse whatever the cache still considers usable.
    cached = cache.get_system_token(app_tid)
    if cached:
        logger.debug("Using cached system token (app_tid=%s)", app_tid)
        return cached

    logger.info("Acquiring system token via mTLS client credentials")
    token, expires_at = _request_token_mtls(
        credentials,
        grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS,
        timeout=timeout,
        config=config,
        app_tid=app_tid,
        extra_data={"response_type": "token"},
    )
    # Store under the same app_tid that the lookup above used.
    cache.set_system_token(token, expires_at, app_tid)
    return token


def exchange_user_token(
    credentials: CustomerCredentials,
    user_token: str,
    timeout: float,
    config: ClientConfig,
    cache: _TokenCache,
    app_tid: str | None = None,
) -> str:
    """Exchange user token for AGW-scoped token using jwt-bearer grant.

    Used for tool invocation where user identity must be preserved
    for principal propagation. Returns a cached exchanged token if the
    cache yields one for ``(user_token, app_tid)``; otherwise acquires
    a fresh one and stores it in the cache.

    Args:
        credentials: Customer credentials.
        user_token: User's JWT token to exchange. Sent as the
            ``assertion`` form field and used as the cache lookup
            argument.
        timeout: HTTP timeout in seconds.
        config: Client configuration (forwarded to the token request,
            which uses it to compute the cache expiry).
        cache: Token cache to consult and update.
        app_tid: BTP Application Tenant ID of subscriber (optional).

    Returns:
        AGW-scoped access token with user identity.

    Raises:
        AgentGatewaySDKError: Propagated from the token request when
            it fails.
    """
    # Fast path: reuse a previously exchanged token for this user/tenant.
    cached = cache.get_user_token(user_token, app_tid)
    if cached:
        logger.debug("Using cached user token (app_tid=%s)", app_tid)
        return cached

    logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant")
    token, expires_at = _request_token_mtls(
        credentials,
        grant_type=_GRANT_TYPE_JWT_BEARER,
        timeout=timeout,
        config=config,
        app_tid=app_tid,
        extra_data={
            "assertion": user_token,
            "token_format": "jwt",
        },
    )
    # Store under the same (user_token, app_tid) pair used for the lookup.
    cache.set_user_token(user_token, token, expires_at, app_tid)
    return token


def _build_mcp_url(gateway_url: str, ord_id: str, gt_id: str) -> str:
Expand Down Expand Up @@ -433,6 +468,8 @@ async def _list_server_tools(
async def get_mcp_tools_customer(
credentials: CustomerCredentials,
timeout: float,
config: ClientConfig,
cache: _TokenCache,
app_tid: str | None = None,
) -> list[MCPTool]:
"""List all MCP tools from servers defined in credentials.
Expand All @@ -442,6 +479,9 @@ async def get_mcp_tools_customer(

Args:
credentials: Customer credentials with integrationDependencies.
timeout: HTTP timeout in seconds.
config: Client configuration.
cache: Token cache shared across calls.
app_tid: BTP Application Tenant ID of subscriber (optional).

Returns:
Expand All @@ -462,7 +502,7 @@ async def get_mcp_tools_customer(
# Get system token for discovery
loop = asyncio.get_running_loop()
system_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, app_tid
None, get_system_token_mtls, credentials, timeout, config, cache, app_tid
)

tools: list[MCPTool] = []
Expand All @@ -480,7 +520,42 @@ async def get_mcp_tools_customer(
server_tools = await _list_server_tools(url, system_token, dep, timeout)
tools.extend(server_tools)
logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id)
except Exception:
except Exception as exc:
unwrapped = _unwrap_exception_group(exc)
if _is_unauthorized(unwrapped):
logger.info(
"401 from %s — invalidating cached system token and retrying",
dep.ord_id,
)
cache.invalidate_system_token(app_tid)
try:
fresh_token = await loop.run_in_executor(
None,
get_system_token_mtls,
credentials,
timeout,
config,
cache,
app_tid,
)
server_tools = await _list_server_tools(
url, fresh_token, dep, timeout
)
tools.extend(server_tools)
# Replace stale token for remaining iterations
system_token = fresh_token
logger.debug(
"Loaded %d tool(s) from %s after retry",
len(server_tools),
dep.ord_id,
)
continue
except Exception:
logger.exception(
"Failed to load tools from %s after retry — skipping",
dep.ord_id,
)
continue
logger.exception("Failed to load tools from %s — skipping", dep.ord_id)

logger.info(
Expand All @@ -494,6 +569,8 @@ async def call_mcp_tool_customer(
tool: MCPTool,
user_token: str | None,
timeout: float,
config: ClientConfig,
cache: _TokenCache,
app_tid: str | None = None,
**kwargs,
) -> str:
Expand All @@ -502,11 +579,16 @@ async def call_mcp_tool_customer(
If user_token is provided, exchanges it for an AGW-scoped token to preserve
user identity for principal propagation. Otherwise, falls back to system token.

On a 401 from the MCP server, drops the cached token and retries once.

Args:
credentials: Customer credentials.
tool: MCPTool to invoke.
user_token: User's JWT token for principal propagation (optional).
If None, system token is used instead (no principal propagation).
timeout: HTTP timeout in seconds.
config: Client configuration.
cache: Token cache shared across calls.
app_tid: BTP Application Tenant ID of subscriber (optional).
**kwargs: Tool input parameters.

Expand All @@ -517,26 +599,74 @@ async def call_mcp_tool_customer(

loop = asyncio.get_running_loop()

if user_token:
# Exchange user token for AGW-scoped token (with principal propagation)
agw_token = await loop.run_in_executor(
None, exchange_user_token, credentials, user_token, timeout, app_tid
)
else:
async def _acquire_token() -> str:
if user_token:
return await loop.run_in_executor(
None,
exchange_user_token,
credentials,
user_token,
timeout,
config,
cache,
app_tid,
)
# TODO: IBD workaround - use system token when user_token is not available.
# This bypasses principal propagation. Remove this fallback once IBD
# supports proper user token flow.
logger.warning(
"No user_token provided - using system token for tool invocation. "
"Principal propagation will NOT work."
)
agw_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, app_tid
return await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, config, cache, app_tid
)

def _invalidate_token() -> None:
if user_token:
cache.invalidate_user_token(user_token, app_tid)
else:
cache.invalidate_system_token(app_tid)

last_exc: Exception | None = None
for attempt in (1, 2):
agw_token = await _acquire_token()
try:
return await _invoke_tool(tool, agw_token, timeout, **kwargs)
except Exception as exc:
unwrapped = _unwrap_exception_group(exc)
if _is_unauthorized(unwrapped) and attempt == 1:
logger.info(
"401 from MCP server for tool '%s' — invalidating cached token and retrying",
tool.name,
)
_invalidate_token()
last_exc = exc
continue
raise

# Defensive — should not be reachable; second attempt either returns or raises.
raise AgentGatewaySDKError(
f"Tool invocation for '{tool.name}' failed after 401 retry: {last_exc}"
)


async def _invoke_tool(
tool: MCPTool,
auth_token: str,
timeout: float,
**kwargs,
) -> str:
"""Open an MCP session to `tool.url` and invoke `tool.name` with `kwargs`.

Returns the first content block's text, or empty string when content is
empty. Raises whatever the MCP transport / session raises (notably
`httpx.HTTPStatusError` on 401, which the caller uses to drive cache
invalidation and retry).
"""
async with httpx.AsyncClient(
headers={
"Authorization": f"Bearer {agw_token}",
"Authorization": f"Bearer {auth_token}",
"x-correlation-id": str(uuid.uuid4()),
},
timeout=timeout,
Expand All @@ -556,3 +686,17 @@ async def call_mcp_tool_customer(

first = result.content[0]
return str(getattr(first, "text", ""))


def _unwrap_exception_group(exc: BaseException) -> BaseException:
"""Unwrap nested ExceptionGroups to find the underlying cause."""
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
exc = exc.exceptions[0]
return exc


def _is_unauthorized(exc: BaseException) -> bool:
    """Detect a 401 response from the MCP server (httpx-based).

    Args:
        exc: Exception to inspect. Exception groups are not unwrapped
            here, so pass the already-unwrapped exception.

    Returns:
        True only when ``exc`` is an ``httpx.HTTPStatusError`` whose
        attached response has status code 401; False for every other
        exception type or status.
    """
    return (
        isinstance(exc, httpx.HTTPStatusError)
        and exc.response is not None
        and exc.response.status_code == 401
    )
Loading