Skip to content
87 changes: 87 additions & 0 deletions benchmarks/bench_credential_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Benchmark: Credential Instance Caching for Azure AD Authentication

Measures the performance difference between:
1. Creating a new DefaultAzureCredential + get_token() each call (old behavior)
2. Reusing a cached DefaultAzureCredential instance (new behavior)

Prerequisites:
- pip install azure-identity azure-core
- az login (for AzureCliCredential to work)

Usage:
python benchmarks/bench_credential_cache.py
"""

from __future__ import annotations

import time
import statistics


def bench_no_cache(n: int) -> list[float]:
Comment thread
jahnvi480 marked this conversation as resolved.
"""Simulate the OLD behavior: new credential per call."""
from azure.identity import DefaultAzureCredential

times = []
for _ in range(n):
start = time.perf_counter()
cred = DefaultAzureCredential()
cred.get_token("https://database.windows.net/.default")
times.append(time.perf_counter() - start)
return times


def bench_with_cache(n: int) -> list[float]:
Comment thread
jahnvi480 marked this conversation as resolved.
"""Simulate the NEW behavior: reuse a single credential instance."""
from azure.identity import DefaultAzureCredential

cred = DefaultAzureCredential()
times = []
for _ in range(n):
start = time.perf_counter()
cred.get_token("https://database.windows.net/.default")
times.append(time.perf_counter() - start)
return times


def report(label: str, times: list[float]) -> None:
Comment thread
jahnvi480 marked this conversation as resolved.
print(f"\n{'=' * 50}")
print(f" {label}")
print(f"{'=' * 50}")
print(f" Calls: {len(times)}")
print(f" Total: {sum(times):.3f}s")
print(f" Mean: {statistics.mean(times) * 1000:.1f}ms")
print(f" Median: {statistics.median(times) * 1000:.1f}ms")
print(f" Stdev: {statistics.stdev(times) * 1000:.1f}ms" if len(times) > 1 else "")
print(f" Min: {min(times) * 1000:.1f}ms")
print(f" Max: {max(times) * 1000:.1f}ms")


def main() -> None:
N = 10 # number of calls to benchmark

print("Credential Instance Cache Benchmark")
print(f"Running {N} sequential token acquisitions for each scenario...\n")

try:
print(">>> Without cache (new credential each call)...")
no_cache_times = bench_no_cache(N)
report("WITHOUT credential cache (old behavior)", no_cache_times)

print("\n>>> With cache (reuse credential instance)...")
cache_times = bench_with_cache(N)
report("WITH credential cache (new behavior)", cache_times)

speedup = statistics.mean(no_cache_times) / statistics.mean(cache_times)
saved = (statistics.mean(no_cache_times) - statistics.mean(cache_times)) * 1000
print(f"\n{'=' * 50}")
print(f" SPEEDUP: {speedup:.1f}x ({saved:.0f}ms saved per call)")
print(f"{'=' * 50}")
except Exception as e:
print(f"\nBenchmark failed: {e}")
print("Make sure you are logged in via 'az login' and have azure-identity installed.")


if __name__ == "__main__":
main()
31 changes: 25 additions & 6 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@

import platform
import struct
import threading
from typing import Tuple, Dict, Optional, List

from mssql_python.logging import logger
from mssql_python.constants import AuthType, ConstantsDDBC

# Module-level credential instance cache.
# Reusing credential objects allows the Azure Identity SDK's built-in
# in-memory token cache to work, avoiding redundant token acquisitions.
# See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md
_credential_cache: Dict[str, object] = {}
Comment thread
jahnvi480 marked this conversation as resolved.
_credential_cache_lock = threading.Lock()


class AADAuth:
"""Handles Azure Active Directory authentication"""
Expand All @@ -36,12 +44,11 @@ def get_token(auth_type: str) -> bytes:

@staticmethod
def get_raw_token(auth_type: str) -> str:
"""Acquire a fresh raw JWT for the mssql-py-core connection (bulk copy).
"""Acquire a raw JWT for the mssql-py-core connection (bulk copy).

This deliberately does NOT cache the credential or token — each call
creates a new Azure Identity credential instance and requests a token.
A fresh acquisition avoids expired-token errors when bulkcopy() is
called long after the original DDBC connect().
Uses the cached credential instance so the Azure Identity SDK's
built-in token cache can serve a valid token without a round-trip
when the previous token has not yet expired.
"""
_, raw_token = AADAuth._acquire_token(auth_type)
return raw_token
Expand Down Expand Up @@ -83,7 +90,19 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
)

try:
credential = credential_class()
with _credential_cache_lock:
if auth_type not in _credential_cache:
logger.debug(
"get_token: Creating new credential instance for auth_type=%s",
auth_type,
)
_credential_cache[auth_type] = credential_class()
else:
logger.debug(
"get_token: Reusing cached credential instance for auth_type=%s",
auth_type,
)
credential = _credential_cache[auth_type]
raw_token = credential.get_token("https://database.windows.net/.default").token
logger.info(
"get_token: Azure AD token acquired successfully - token_length=%d chars",
Expand Down
205 changes: 205 additions & 0 deletions tests/test_008_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
get_auth_token,
process_connection_string,
extract_auth_type,
_credential_cache,
_credential_cache_lock,
)
from mssql_python.constants import AuthType, ConstantsDDBC
import secrets
Expand Down Expand Up @@ -71,6 +73,14 @@ class exceptions:
del sys.modules[module]


@pytest.fixture(autouse=True)
def clear_credential_cache():
"""Clear the module-level credential cache between tests."""
_credential_cache.clear()
yield
_credential_cache.clear()


class TestAuthType:
def test_auth_type_constants(self):
assert AuthType.INTERACTIVE.value == "activedirectoryinteractive"
Expand Down Expand Up @@ -403,6 +413,201 @@ def test_unsupported_auth(self):
assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None


class TestCredentialInstanceCache:
"""Tests for the credential instance caching behavior."""

def test_credential_reused_across_calls(self):
"""The same credential instance should be returned for repeated calls."""
AADAuth.get_token("default")
assert "default" in _credential_cache
first_instance = _credential_cache["default"]

AADAuth.get_token("default")
assert _credential_cache["default"] is first_instance

def test_different_auth_types_get_separate_instances(self):
"""Each auth type should have its own cached credential."""
AADAuth.get_token("default")
AADAuth.get_token("devicecode")

assert "default" in _credential_cache
assert "devicecode" in _credential_cache
assert _credential_cache["default"] is not _credential_cache["devicecode"]

def test_get_raw_token_uses_cached_credential(self):
"""get_raw_token should also use the cached credential instance."""
AADAuth.get_token("default")
cached = _credential_cache["default"]

AADAuth.get_raw_token("default")
assert _credential_cache["default"] is cached

def test_cache_starts_empty(self):
"""Cache should be empty at the start due to the clear_credential_cache fixture."""
assert len(_credential_cache) == 0

def test_cached_credential_refreshes_token_after_expiry(self):
"""Verify that the cached credential instance returns fresh tokens on each call.

This simulates what happens when Azure Identity SDK refreshes an expired
token internally: because we cache the credential (not the token), each
_acquire_token() call invokes get_token() on the same instance, giving
the SDK the opportunity to return a refreshed token when the old one has
expired.
"""
import sys

azure_identity = sys.modules["azure.identity"]
original = azure_identity.DefaultAzureCredential

call_count = 0
tokens = ["initial_token_abc123", "refreshed_token_xyz789"]

class MockCredentialWithRefresh:
def get_token(self, scope):
nonlocal call_count
idx = min(call_count, len(tokens) - 1)
call_count += 1

class Token:
token = tokens[idx]

return Token()

try:
azure_identity.DefaultAzureCredential = MockCredentialWithRefresh

# First call — gets initial token
_, raw_token_1 = AADAuth._acquire_token("default")
assert raw_token_1 == "initial_token_abc123"
assert call_count == 1

# Same credential instance is cached
cached = _credential_cache["default"]
assert isinstance(cached, MockCredentialWithRefresh)

# Second call — same credential instance, but SDK returns refreshed token
# (simulating post-expiry refresh)
_, raw_token_2 = AADAuth._acquire_token("default")
assert raw_token_2 == "refreshed_token_xyz789"
assert call_count == 2

# Credential instance is still the same (not recreated)
assert _credential_cache["default"] is cached
finally:
azure_identity.DefaultAzureCredential = original


class TestAcquireTokenImportError:
"""Test the ImportError path when azure-identity is not installed."""

def test_import_error_raises_runtime_error(self):
"""_acquire_token raises RuntimeError when azure.identity is missing."""
import sys

# Temporarily remove the mocked azure modules
saved = {}
for mod_name in list(sys.modules):
if mod_name == "azure" or mod_name.startswith("azure."):
saved[mod_name] = sys.modules.pop(mod_name)

# Make the import fail
import builtins

real_import = builtins.__import__

def blocked_import(name, *args, **kwargs):
if name.startswith("azure"):
raise ImportError("No module named 'azure'")
return real_import(name, *args, **kwargs)

builtins.__import__ = blocked_import
try:
with pytest.raises(
RuntimeError, match="Azure authentication libraries are not installed"
):
AADAuth._acquire_token("default")
finally:
builtins.__import__ = real_import
sys.modules.update(saved)


class TestAcquireTokenClientAuthError:
"""Test the ClientAuthenticationError path inside _acquire_token."""

def test_client_auth_error_in_acquire_token(self):
"""ClientAuthenticationError during get_token is wrapped in RuntimeError."""
import sys

azure_identity = sys.modules["azure.identity"]
original = azure_identity.DefaultAzureCredential

from azure.core.exceptions import ClientAuthenticationError

class FailingCredential:
def get_token(self, scope):
raise ClientAuthenticationError("token request denied")

try:
azure_identity.DefaultAzureCredential = FailingCredential
with pytest.raises(RuntimeError, match="Azure AD authentication failed"):
AADAuth._acquire_token("default")
finally:
azure_identity.DefaultAzureCredential = original


class TestProcessAuthParametersEdgeCases:
"""Cover empty-param and no-equals-sign branches."""

def test_empty_and_whitespace_params_skipped(self):
params = ["Server=test", "", " ", "Database=db"]
modified, auth_type = process_auth_parameters(params)
assert "Server=test" in modified
assert "Database=db" in modified
assert auth_type is None

def test_param_without_equals_kept(self):
params = ["Server=test", "SomeFlag", "Database=db"]
modified, auth_type = process_auth_parameters(params)
assert "SomeFlag" in modified
assert "Server=test" in modified


class TestGetAuthTokenEdgeCases:
"""Cover the Windows-interactive and token-failure branches."""

def test_no_auth_type_returns_none(self):
result = get_auth_token(None)
assert result is None

def test_empty_auth_type_returns_none(self):
result = get_auth_token("")
assert result is None

def test_windows_interactive_returns_none(self, monkeypatch):
monkeypatch.setattr(platform, "system", lambda: "Windows")
result = get_auth_token("interactive")
assert result is None

def test_token_acquisition_failure_returns_none(self):
"""When AADAuth.get_token raises, get_auth_token returns None."""
import sys

azure_identity = sys.modules["azure.identity"]
original = azure_identity.DefaultAzureCredential

class FailingCredential:
def __init__(self):
raise RuntimeError("credential creation exploded")

try:
azure_identity.DefaultAzureCredential = FailingCredential
result = get_auth_token("default")
assert result is None
finally:
azure_identity.DefaultAzureCredential = original


Comment thread
jahnvi480 marked this conversation as resolved.
def test_acquire_token_unsupported_auth_type():
with pytest.raises(ValueError, match="Unsupported auth_type 'bogus'"):
AADAuth._acquire_token("bogus")
Expand Down
Loading