diff --git a/CHANGES/pulp-glue/+aiohttp.feature b/CHANGES/pulp-glue/+aiohttp.feature
new file mode 100644
index 000000000..70dac14dc
--- /dev/null
+++ b/CHANGES/pulp-glue/+aiohttp.feature
@@ -0,0 +1 @@
+WIP: Added async api to Pulp glue.
diff --git a/CHANGES/pulp-glue/+aiohttp.removal b/CHANGES/pulp-glue/+aiohttp.removal
new file mode 100644
index 000000000..4d5165bf5
--- /dev/null
+++ b/CHANGES/pulp-glue/+aiohttp.removal
@@ -0,0 +1,2 @@
+Replaced requests with aiohttp.
+Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library.
diff --git a/lint_requirements.txt b/lint_requirements.txt
index 1246a40c4..df74a7923 100644
--- a/lint_requirements.txt
+++ b/lint_requirements.txt
@@ -4,9 +4,9 @@ mypy~=1.19.1
 shellcheck-py~=0.11.0.1
 
 # Type annotation stubs
+types-aiofiles
 types-pygments
 types-PyYAML
-types-requests
 types-setuptools
 types-toml
 
diff --git a/lower_bounds_constraints.lock b/lower_bounds_constraints.lock
index a2e1857e1..3aad3a3ed 100644
--- a/lower_bounds_constraints.lock
+++ b/lower_bounds_constraints.lock
@@ -1,3 +1,5 @@
+aiofiles==25.1.0
+aiohttp==3.12.0
 click==8.0.0
 packaging==22.0
 PyYAML==5.3
diff --git a/pulp-glue/pyproject.toml b/pulp-glue/pyproject.toml
index a56f6e5ce..2f9991046 100644
--- a/pulp-glue/pyproject.toml
+++ b/pulp-glue/pyproject.toml
@@ -22,9 +22,10 @@ classifiers = [
     "Typing :: Typed",
 ]
 dependencies = [
+    "aiofiles>=25.1.0,<25.2",
+    "aiohttp>=3.12.0,<3.14",
     "multidict>=6.0.5,<6.8",
     "packaging>=22.0,<=26.0",  # CalVer
-    "requests>=2.24.0,<2.33",
     "tomli>=2.0.0,<2.1;python_version<'3.11'",
 ]
 
diff --git a/pulp-glue/src/pulp_glue/common/openapi.py b/pulp-glue/src/pulp_glue/common/openapi.py
index 443817d32..86b5ffc9c 100644
--- a/pulp-glue/src/pulp_glue/common/openapi.py
+++ b/pulp-glue/src/pulp_glue/common/openapi.py
@@ -12,8 +12,9 @@
 from io import BufferedReader
 from urllib.parse import urlencode, urljoin
 
-import requests
-import urllib3
+import aiofiles
+import aiofiles.os
+import aiohttp
 from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping
 
 from pulp_glue.common import __version__
@@ -136,8 +137,6 @@ def __init__(
         if cid:
             self._headers["Correlation-Id"] = cid
 
-        self._setup_session()
-
         self._oauth2_lock = asyncio.Lock()
         self._oauth2_token: str | None = None
         self._oauth2_expires: datetime = datetime.now()
@@ -145,29 +144,6 @@ def __init__(
         self._patch_api_hook: t.Callable[[t.Any], t.Any] = patch_api_hook or (lambda data: data)
         self.load_api(refresh_cache=refresh_cache)
 
-    def _setup_session(self) -> None:
-        # This is specific requests library.
-
-        if self._verify_ssl is False:
-            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
-        self._session: requests.Session = requests.session()
-        # Don't redirect, because carrying auth accross redirects is unsafe.
-        self._session.max_redirects = 0
-        self._session.headers.update(self._headers)
-        session_settings = self._session.merge_environment_settings(
-            self._base_url, {}, None, self._verify_ssl, None
-        )
-        self._session.verify = session_settings["verify"]
-        self._session.proxies = session_settings["proxies"]
-
-        if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS():
-            cert, key = self._auth_provider.tls_credentials()
-            if key is not None:
-                self._session.cert = (cert, key)
-            else:
-                self._session.cert = cert
-
     @property
     def base_url(self) -> str:
         return self._base_url
@@ -191,7 +167,10 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
         return _ssl_context
 
     def load_api(self, refresh_cache: bool = False) -> None:
-        # TODO: Find a way to invalidate caches on upstream change
+        asyncio.run(self._load_api(refresh_cache=refresh_cache))
+
+    async def _load_api(self, refresh_cache: bool = False) -> None:
+        # TODO: Find a way to invalidate caches on upstream change.
         xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
         apidoc_cache: str = os.path.join(
             os.path.expanduser(xdg_cache_home),
@@ -203,17 +182,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
             if refresh_cache:
                 # Fake that we did not find the cache.
                 raise OSError()
-            with open(apidoc_cache, "rb") as f:
-                data: bytes = f.read()
+            async with aiofiles.open(apidoc_cache, mode="rb") as f:
+                data: bytes = await f.read()
             self._parse_api(data)
         except Exception:
-            # Try again with a freshly downloaded version
-            data = self._download_api()
+            # Try again with a freshly downloaded version.
+            data = await self._download_api()
             self._parse_api(data)
-            # Write to cache as it seems to be valid
-            os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
-            with open(apidoc_cache, "bw") as f:
-                f.write(data)
+            # Write to cache as it seems to be valid.
+            await aiofiles.os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
+            async with aiofiles.open(apidoc_cache, mode="bw") as f:
+                await f.write(data)
 
     def _parse_api(self, data: bytes) -> None:
         raw_spec = self._patch_api_hook(json.loads(data))
@@ -229,15 +208,18 @@ def _parse_api(self, data: bytes) -> None:
             if method in METHODS
         }
 
-    def _download_api(self) -> bytes:
-        try:
-            response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path))
-        except requests.RequestException as e:
-            raise OpenAPIError(str(e))
-        response.raise_for_status()
-        if "Correlation-Id" in response.headers:
-            self._set_correlation_id(response.headers["Correlation-Id"])
-        return response.content
+    async def _download_api(self) -> bytes:
+        response = await self._send_request(
+            _Request(
+                operation_id="",
+                method="get",
+                url=urljoin(self._base_url, self._doc_path),
+                headers=self._headers,
+            )
+        )
+        if response.status_code != 200:
+            raise OpenAPIError(_("Failed to find api docs."))
+        return response.body
 
     def _set_correlation_id(self, correlation_id: str) -> None:
         if "Correlation-Id" in self._headers:
@@ -249,8 +231,6 @@ def _set_correlation_id(self, correlation_id: str) -> None:
                 )
         else:
             self._headers["Correlation-Id"] = correlation_id
-            # Do it for requests too...
-            self._session.headers["Correlation-Id"] = correlation_id
 
     def param_spec(
         self, operation_id: str, param_type: str, required: bool = False
@@ -467,7 +447,7 @@ def _render_request(
             security=security,
         )
 
-    def _log_request(self, request: _Request) -> None:
+    async def _log_request(self, request: _Request) -> None:
         if request.params:
             qs = urlencode(request.params)
             self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}?{qs}")
@@ -493,7 +473,6 @@ def _select_proposal(
         if (
             request.security
             and "Authorization" not in request.headers
-            and "Authorization" not in self._session.headers
             and self._auth_provider is not None
         ):
             security_schemes: dict[str, dict[str, t.Any]] = self.api_spec["components"][
@@ -565,7 +544,7 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
                 headers={"Authorization": f"Basic {secret.decode()}"},
                 data=data,
             )
-            response = self._send_request(request)
+            response = await self._send_request(request)
             if response.status_code < 200 or response.status_code >= 300:
                 raise OpenAPIError("Failed to fetch OAuth2 token")
             result = json.loads(response.body)
@@ -574,38 +553,55 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
             new_token = True
         return new_token
 
-    def _send_request(
+    async def _send_request(
         self,
         request: _Request,
     ) -> _Response:
-        # This function uses requests to translate the _Request into a _Response.
+        # This function uses aiohttp to translate the _Request into a _Response.
+        data: aiohttp.FormData | dict[str, t.Any] | str | None
+        if request.files:
+            assert isinstance(request.data, dict)
+            # Maybe assert on the content type header.
+            data = aiohttp.FormData(default_to_multipart=True)
+            for key, value in request.data.items():
+                data.add_field(key, encode_param(value))
+            for key, (name, value, content_type) in request.files.items():
+                data.add_field(key, value, filename=name, content_type=content_type)
+        else:
+            data = request.data
 
         try:
-            r = self._session.request(
-                request.method,
-                request.url,
-                params=request.params,
-                headers=request.headers,
-                data=request.data,
-                files=request.files,
-            )
-            response = _Response(status_code=r.status_code, headers=r.headers, body=r.content)
-        except requests.TooManyRedirects as e:
-            assert e.response is not None
+            async with aiohttp.ClientSession() as session:
+                async with session.request(
+                    request.method,
+                    request.url,
+                    params=request.params,
+                    headers=request.headers,
+                    data=data,
+                    ssl=self.ssl_context,
+                    max_redirects=0,
+                ) as r:
+                    response_body = await r.read()
+                    response = _Response(
+                        status_code=r.status, headers=r.headers, body=response_body
+                    )
+        except aiohttp.TooManyRedirects as e:
+            # We could handle that in the middleware...
+            assert e.history[-1] is not None
             raise OpenAPIError(
                 _(
                     "Received redirect to '{new_url} from {old_url}'."
                     " Please check your configuration."
                 ).format(
-                    new_url=e.response.headers["location"],
+                    new_url=e.history[-1].headers["location"],
                     old_url=request.url,
                 )
             )
-        except requests.RequestException as e:
+        except aiohttp.ClientResponseError as e:
             raise OpenAPIError(str(e))
         return response
 
-    def _log_response(self, response: _Response) -> None:
+    async def _log_response(self, response: _Response) -> None:
         self._debug_callback(
             1, _("Response: {status_code}").format(status_code=response.status_code)
         )
@@ -652,6 +648,22 @@ def call(
         parameters: dict[str, t.Any] | None = None,
         body: dict[str, t.Any] | None = None,
         validate_body: bool = True,
+    ) -> t.Any:
+        return asyncio.run(
+            self.async_call(
+                operation_id=operation_id,
+                parameters=parameters,
+                body=body,
+                validate_body=validate_body,
+            )
+        )
+
+    async def async_call(
+        self,
+        operation_id: str,
+        parameters: dict[str, t.Any] | None = None,
+        body: dict[str, t.Any] | None = None,
+        validate_body: bool = True,
     ) -> t.Any:
         """
         Make a call to the server.
@@ -706,7 +718,7 @@ def call(
             body,
             validate_body=validate_body,
         )
-        self._log_request(request)
+        await self._log_request(request)
         if self._dry_run and request.method.upper() not in SAFE_METHODS:
             raise UnsafeCallError(_("Call aborted due to safe mode"))
 
@@ -714,29 +726,25 @@ def call(
         may_retry = False
         if proposal := self._select_proposal(request):
             assert len(proposal) == 1, "More complex security proposals are not implemented."
-            may_retry = asyncio.run(self._authenticate_request(request, proposal))
+            may_retry = await self._authenticate_request(request, proposal)
 
-        response = self._send_request(request)
+        response = await self._send_request(request)
 
         if proposal is not None:
             assert self._auth_provider is not None
             if may_retry and response.status_code == 401:
                 self._oauth2_token = None
-                asyncio.run(self._authenticate_request(request, proposal))
-                response = self._send_request(request)
+                await self._authenticate_request(request, proposal)
+                response = await self._send_request(request)
             if response.status_code >= 200 and response.status_code < 300:
-                asyncio.run(
-                    self._auth_provider.auth_success_hook(
-                        proposal, self.api_spec["components"]["securitySchemes"]
-                    )
+                await self._auth_provider.auth_success_hook(
+                    proposal, self.api_spec["components"]["securitySchemes"]
                 )
             elif response.status_code == 401:
-                asyncio.run(
-                    self._auth_provider.auth_failure_hook(
-                        proposal, self.api_spec["components"]["securitySchemes"]
-                    )
+                await self._auth_provider.auth_failure_hook(
+                    proposal, self.api_spec["components"]["securitySchemes"]
                 )
 
-        self._log_response(response)
+        await self._log_response(response)
 
         return self._parse_response(method_spec, response)
diff --git a/pulp-glue/tests/test_auth_provider.py b/pulp-glue/tests/test_auth_provider.py
index 0bcde6da0..07eed191d 100644
--- a/pulp-glue/tests/test_auth_provider.py
+++ b/pulp-glue/tests/test_auth_provider.py
@@ -63,10 +63,7 @@ def test_can_complete_basic(self, provider: AuthProviderBase) -> None:
         assert provider.can_complete_http_basic()
 
     def test_provides_username_and_password(self, provider: AuthProviderBase) -> None:
-        assert asyncio.run(provider.http_basic_credentials()) == (
-            b"user1",
-            b"password1",
-        )
+        assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"password1")
 
     def test_cannot_complete_mutualTLS(self, provider: AuthProviderBase) -> None:
         assert not provider.can_complete_mutualTLS()
@@ -104,10 +101,7 @@ def test_client_id_needs_client_secret(self) -> None:
     def test_can_complete_oauth2_client_credentials_and_provide_them(self) -> None:
         provider = GlueAuthProvider(client_id="client1", client_secret="secret1")
         assert provider.can_complete_oauth2_client_credentials([]) is True
-        assert asyncio.run(provider.oauth2_client_credentials()) == (
-            b"client1",
-            b"secret1",
-        )
+        assert asyncio.run(provider.oauth2_client_credentials()) == (b"client1", b"secret1")
 
     def test_can_complete_mutualTLS_and_provide_cert(self) -> None:
         provider = GlueAuthProvider(cert="FAKECERTIFICATE")
diff --git a/pulp-glue/tests/test_openapi.py b/pulp-glue/tests/test_openapi.py
index 0d9e32668..3b03e828b 100644
--- a/pulp-glue/tests/test_openapi.py
+++ b/pulp-glue/tests/test_openapi.py
@@ -98,7 +98,7 @@
 ).encode()
 
 
-def mock_send_request(request: _Request) -> _Response:
+async def mock_send_request(request: _Request) -> _Response:
     if request.url.endswith("oauth/token"):
         assert request.method.lower() == "post"
         # $ echo -n "client1:secret1" | base64
diff --git a/src/pulpcore/cli/core/task.py b/src/pulpcore/cli/core/task.py
index 6ece44767..98151e748 100644
--- a/src/pulpcore/cli/core/task.py
+++ b/src/pulpcore/cli/core/task.py
@@ -1,13 +1,17 @@
+import asyncio
 import re
 from contextlib import suppress
 from datetime import datetime
 from pathlib import Path
 
+import aiofiles
+import aiohttp
 import click
 
 from pulp_glue.common.context import (
     DATETIME_FORMATS,
     PluginRequirement,
+    PulpContext,
     PulpEntityContext,
 )
 from pulp_glue.common.exceptions import PulpException
@@ -175,6 +179,20 @@ def cancel(
     task_ctx.cancel(task_ctx.pulp_href)
 
 
+async def _download_artifacts(
+    pulp_ctx: PulpContext, urls: dict[str, str], profile_artifact_dir: Path
+) -> None:
+    async with aiohttp.ClientSession() as session:
+        for name, url in urls.items():
+            profile_artifact_path = profile_artifact_dir / name
+            click.echo(_("Downloading {path}").format(path=profile_artifact_path))
+            async with session.get(url, ssl=pulp_ctx.api.ssl_context) as response:
+                assert response.status == 200
+                async with aiofiles.open(profile_artifact_path, "wb") as fp:
+                    async for chunk in response.content.iter_chunked(1024):
+                        await fp.write(chunk)
+
+
 @task.command()
 @href_option
 @uuid_option
@@ -197,13 +215,7 @@ def profile_artifact_urls(
     uuid = uuid_match.group("uuid")
     profile_artifact_dir = Path(".") / f"task_profile-{task_name}-{uuid}"
     profile_artifact_dir.mkdir(exist_ok=True)
-    with pulp_ctx.api._session as session:
-        for name, url in urls.items():
-            profile_artifact_path = profile_artifact_dir / name
-            click.echo(_("Downloading {path}").format(path=profile_artifact_path))
-            response = session.get(url)
-            response.raise_for_status()
-            profile_artifact_path.write_bytes(response.content)
+    asyncio.run(_download_artifacts(pulp_ctx, urls, profile_artifact_dir))
 
 
 @task.command()
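
Usage sketch of the new entry point this patch introduces: `async_call()` can be awaited directly, while the synchronous `call()` keeps working by delegating to it via `asyncio.run()`. The constructor arguments, `doc_path`, and the `status_read` operation id below are illustrative assumptions, not taken from this diff.

```python
# Hypothetical example, assuming OpenAPI(base_url=..., doc_path=...) and a
# pulpcore "status_read" operation; adjust to your actual deployment.
import asyncio

from pulp_glue.common.openapi import OpenAPI

# Construct outside a running event loop: __init__ still calls load_api(),
# which itself uses asyncio.run() in this patch.
api = OpenAPI(base_url="https://pulp.example.com", doc_path="/pulp/api/v3/docs/api.json")


async def main() -> None:
    # New awaitable variant; the synchronous call() now wraps async_call().
    status = await api.async_call("status_read")
    print(status)


asyncio.run(main())
```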