Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions packages/google-auth/google/auth/aio/transport/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ async def configure_mtls_channel(self, client_cert_callback=None):
(via GOOGLE_API_USE_CLIENT_CERTIFICATE=true) or auto-enabled (when the env
variable is unset and workload certificates are discovered). In these cases,
the underlying transport will be reconfigured to use mTLS.

Note: This function does nothing if the `aiohttp` library is not
installed.
Important: Calling this method will close any ongoing API requests associated
Expand Down Expand Up @@ -220,11 +219,7 @@ async def _do_configure():
UserWarning,
)

except (
exceptions.ClientCertError,
ImportError,
OSError,
) as caught_exc:
except Exception as caught_exc:
self._is_mtls = False
Comment thread
vverman marked this conversation as resolved.
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc
Expand Down Expand Up @@ -586,4 +581,10 @@ async def close(self) -> None:
"""
Close the underlying auth request session.
"""
if self._mtls_init_task and not self._mtls_init_task.done():
self._mtls_init_task.cancel()
try:
await self._mtls_init_task
except asyncio.CancelledError:
pass
await self._auth_request.close()
41 changes: 23 additions & 18 deletions packages/google-auth/google/auth/transport/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from google.auth import exceptions
from google.auth.transport import _mtls_helper
from google.auth.transport import mtls
from google.oauth2 import service_account

try:
Expand Down Expand Up @@ -279,14 +280,18 @@ def my_client_cert_callback():
class SslCredentials:
"""Class for application default SSL credentials.

The behavior is controlled by `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment
variable whose default value is `false`. Client certificate will not be used
unless the environment variable is explicitly set to `true`. See
https://google.aip.dev/auth/4114

If the environment variable is `true`, then for devices with endpoint verification
support, a device certificate will be automatically loaded and mutual TLS will
be established.
The client certificate usage (mutual TLS) is determined by the
`should_use_client_cert` helper. Client certificate will not be used
unless client certificate usage is enabled. This is true if the
`GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is explicitly
set to `"true"`, or if the environment variable is unset/empty but a client
certificate configuration is found (e.g. via the `GOOGLE_API_CERTIFICATE_CONFIG`
environment variable containing a `"workload"` certificate configuration).
See https://google.aip.dev/auth/4114

If client certificate usage is enabled, then for devices with endpoint
verification support, a device certificate will be automatically loaded and
mutual TLS will be established.
See https://cloud.google.com/endpoint-verification/docs/overview.
"""

Expand All @@ -295,11 +300,7 @@ def __init__(self):
if not use_client_cert:
self._is_mtls = False
else:
# Load client SSL credentials.
metadata_path = _mtls_helper._check_config_path(
_mtls_helper.CONTEXT_AWARE_METADATA_PATH
)
self._is_mtls = metadata_path is not None
self._is_mtls = mtls.has_default_client_cert_source()

@property
def ssl_credentials(self):
Comment thread
nbayati marked this conversation as resolved.
Expand All @@ -319,11 +320,15 @@ def ssl_credentials(self):
"""
if self._is_mtls:
try:
_, cert, key, _ = _mtls_helper.get_client_ssl_credentials()
self._ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
except exceptions.ClientCertError as caught_exc:
has_cert, cert, key, _ = _mtls_helper.get_client_ssl_credentials()
if has_cert:
self._ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
else:
self._ssl_credentials = grpc.ssl_channel_credentials()
self._is_mtls = False
except (exceptions.ClientCertError, OSError) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc
else:
Expand Down
25 changes: 16 additions & 9 deletions packages/google-auth/google/auth/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,42 +443,49 @@ def configure_mtls_channel(self, client_cert_callback=None):

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
creation failed for any reason.
creation failed for any reason. The existing session state (such
as adapter mounts) remains unmodified if this error is raised.
"""
use_client_cert = google.auth.transport._mtls_helper.check_use_client_cert()
if not use_client_cert:
self._is_mtls = False
return

try:
import OpenSSL
except ImportError as caught_exc:
self._is_mtls = False
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

try:
(
self._is_mtls,
is_mtls,
cert,
key,
) = google.auth.transport._mtls_helper.get_client_cert_and_key(
client_cert_callback
)

if self._is_mtls:
mtls_adapter = _MutualTlsAdapter(cert, key)
self._cached_cert = cert
self.mount("https://", mtls_adapter)
if is_mtls:
new_adapter = _MutualTlsAdapter(cert, key)
else:
new_adapter = requests.adapters.HTTPAdapter()
except (
exceptions.ClientCertError,
ImportError,
OSError,
OpenSSL.crypto.Error,
) as caught_exc:
self._is_mtls = False
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

self.mount("https://", new_adapter)
self._is_mtls = is_mtls
if is_mtls:
self._cached_cert = cert
else:
if hasattr(self, "_cached_cert"):
del self._cached_cert

def request(
self,
method,
Expand Down
26 changes: 16 additions & 10 deletions packages/google-auth/google/auth/transport/urllib3.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,18 +332,16 @@ def configure_mtls_channel(self, client_cert_callback=None):

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
creation failed for any reason.
creation failed for any reason. The existing channel state (the
HTTP client) remains unmodified if this error is raised.
"""
use_client_cert = transport._mtls_helper.check_use_client_cert()
if not use_client_cert:
self._is_mtls = False
return False
else:
self._is_mtls = True

try:
import OpenSSL
except ImportError as caught_exc:
self._is_mtls = False
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

Expand All @@ -353,21 +351,29 @@ def configure_mtls_channel(self, client_cert_callback=None):
)

if found_cert_key:
self.http = _make_mutual_tls_http(cert, key)
self._cached_cert = cert
new_http = _make_mutual_tls_http(cert, key)
new_is_mtls = True
else:
self.http = _make_default_http()
self._is_mtls = False
new_http = _make_default_http()
new_is_mtls = False
except (
Comment thread
vverman marked this conversation as resolved.
exceptions.ClientCertError,
ImportError,
OSError,
OpenSSL.crypto.Error,
) as caught_exc:
self._is_mtls = False
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

self.http = new_http
self._is_mtls = new_is_mtls
self._request.http = new_http
if new_is_mtls:
self._cached_cert = cert
else:
if hasattr(self, "_cached_cert"):
del self._cached_cert

if self._has_user_provided_http:
self._has_user_provided_http = False
warnings.warn(
Expand Down
73 changes: 71 additions & 2 deletions packages/google-auth/tests/transport/aio/test_sessions_mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def test_configure_mtls_channel(self):
mock_make_context.assert_called_once_with(
b"fake_cert_data", b"fake_key_data"
)
await session.close()

@pytest.mark.asyncio
async def test_configure_mtls_channel_disabled(self):
Expand All @@ -78,6 +79,7 @@ async def test_configure_mtls_channel_disabled(self):
session = sessions.AsyncAuthorizedSession(mock_creds)
await session.configure_mtls_channel()
assert session._is_mtls is False
await session.close()

@pytest.mark.asyncio
async def test_configure_mtls_channel_invalid_format(self):
Expand All @@ -95,6 +97,7 @@ async def test_configure_mtls_channel_invalid_format(self):

with pytest.raises(exceptions.MutualTLSChannelError):
await session.configure_mtls_channel()
await session.close()

@pytest.mark.asyncio
async def test_configure_mtls_channel_invalud_fields(self):
Expand All @@ -111,6 +114,7 @@ async def test_configure_mtls_channel_invalud_fields(self):
session = sessions.AsyncAuthorizedSession(mock_creds)
await session.configure_mtls_channel()
assert session._is_mtls is False
await session.close()

@pytest.mark.asyncio
async def test_configure_mtls_channel_mock_callback(self):
Expand All @@ -135,11 +139,12 @@ def mock_callback():
await session.configure_mtls_channel(client_cert_callback=mock_callback)

assert session._is_mtls is True
await session.close()

@pytest.mark.asyncio
async def test_configure_mtls_channel_custom_request(self):
"""
Tests that if _auth_request is not an AiohttpRequest, it gracefully falls back to tLS.
"""Tests that if _auth_request is not an AiohttpRequest, _is_mtls is set to False
because we can't configure the custom request with mTLS.
"""
with mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}
Expand All @@ -161,5 +166,69 @@ async def test_configure_mtls_channel_custom_request(self):
session = sessions.AsyncAuthorizedSession(
Comment thread
vverman marked this conversation as resolved.
mock_creds, auth_request=mock_auth_request
)

await session.configure_mtls_channel()

# If the request handler is not an AiohttpRequest, the library cannot configure
# the connection to use mTLS, so _is_mtls must be False to reflect this unconfigured state.
assert session._is_mtls is False
mock_make_context.assert_called_once_with(
b"fake_cert_data", b"fake_key_data"
)
await session.close()

@pytest.mark.asyncio
async def test_configure_mtls_channel_exception_resets_flag(self):
"""
Tests that self._is_mtls is reset to False if an exception is raised
during configuration.
"""
with mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}
), mock.patch("os.path.exists") as mock_exists, mock.patch(
"builtins.open", mock.mock_open(read_data=json.dumps(VALID_WORKLOAD_CONFIG))
), mock.patch(
"google.auth.aio.transport.mtls.get_client_cert_and_key"
) as mock_helper, mock.patch(
"google.auth.aio.transport.mtls.make_client_cert_ssl_context"
) as mock_make_context:
mock_exists.return_value = True
mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data")
mock_make_context.side_effect = exceptions.ClientCertError("Mock error")

mock_creds = mock.AsyncMock(spec=credentials.Credentials)
session = sessions.AsyncAuthorizedSession(mock_creds)

with pytest.raises(exceptions.MutualTLSChannelError):
await session.configure_mtls_channel()

assert session._is_mtls is False
await session.close()

@pytest.mark.asyncio
async def test_configure_mtls_channel_transport_error_resets_flag(self):
"""
Tests that self._is_mtls is reset to False if a TransportError is raised
during configuration.
"""
with mock.patch.dict(
os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}
), mock.patch("os.path.exists") as mock_exists, mock.patch(
"builtins.open", mock.mock_open(read_data=json.dumps(VALID_WORKLOAD_CONFIG))
), mock.patch(
"google.auth.aio.transport.mtls.get_client_cert_and_key"
) as mock_helper, mock.patch(
"google.auth.aio.transport.mtls.make_client_cert_ssl_context"
) as mock_make_context:
mock_exists.return_value = True
mock_helper.return_value = (True, b"fake_cert_data", b"fake_key_data")
mock_make_context.side_effect = exceptions.TransportError("Mock error")

mock_creds = mock.AsyncMock(spec=credentials.Credentials)
session = sessions.AsyncAuthorizedSession(mock_creds)

with pytest.raises(exceptions.MutualTLSChannelError):
await session.configure_mtls_channel()

assert session._is_mtls is False
await session.close()
Loading
Loading