diff --git a/lib/galaxy/authnz/psa_authnz.py b/lib/galaxy/authnz/psa_authnz.py index 5d1189b38370..e876f919efe8 100644 --- a/lib/galaxy/authnz/psa_authnz.py +++ b/lib/galaxy/authnz/psa_authnz.py @@ -21,7 +21,9 @@ from sqlalchemy import func from sqlalchemy.exc import IntegrityError -from galaxy import exceptions as galaxy_exceptions +from galaxy import ( + exceptions as galaxy_exceptions, +) from galaxy.config import GalaxyAppConfiguration from galaxy.exceptions import MalformedContents from galaxy.managers import users as user_managers @@ -54,6 +56,19 @@ log = logging.getLogger(__name__) + +def locate_token_expiration(extra_data): + expires = extra_data.get("expires", None) or extra_data.get("expires_in", None) + if expires: + return expires + + refresh_token = extra_data.get("refresh_token") + if refresh_token and isinstance(refresh_token, dict): + return refresh_token.get("expires", None) or refresh_token.get("expires_in", None) + + return None + + # key: a component name which PSA requests. # value: is the name of a class associated with that key. DEFAULTS = {"STRATEGY": "Strategy", "STORAGE": "Storage"} @@ -291,17 +306,7 @@ def refresh(self, trans, user_authnz_token): return False def _try_to_locate_refresh_token_expiration(self, extra_data): - # Try to get expiration from top-level keys - expires = extra_data.get("expires", None) or extra_data.get("expires_in", None) - if expires: - return expires - - # Try to get expiration from refresh_token if it's a dict - refresh_token = extra_data.get("refresh_token") - if refresh_token and isinstance(refresh_token, dict): - return refresh_token.get("expires", None) or refresh_token.get("expires_in", None) - - return None + return locate_token_expiration(extra_data) def authenticate(self, trans, idphint=None) -> "HttpResponseProtocol": on_the_fly_config(trans.sa_session) diff --git a/lib/galaxy/exceptions/__init__.py b/lib/galaxy/exceptions/__init__.py index ec1039bfcaaa..55b869558e46 100644 --- a/lib/galaxy/exceptions/__init__.py +++ b/lib/galaxy/exceptions/__init__.py @@ -176,6 +176,11 @@ class AuthenticationFailed(MessageException): err_code = error_codes_by_name["USER_AUTHENTICATION_FAILED"] +class FileSourceCredentialExpired(MessageException): + status_code = 401 + err_code = error_codes_by_name["FILE_SOURCE_CREDENTIAL_EXPIRED"] + + class AuthenticationRequired(MessageException): status_code = 403 # TODO: as 401 and send WWW-Authenticate: ??? diff --git a/lib/galaxy/exceptions/error_codes.json b/lib/galaxy/exceptions/error_codes.json index f4265ca3962a..f90a2eb878a6 100644 --- a/lib/galaxy/exceptions/error_codes.json +++ b/lib/galaxy/exceptions/error_codes.json @@ -104,6 +104,11 @@ "code": 401001, "message": "Authentication failed, invalid credentials supplied." }, + { + "name": "FILE_SOURCE_CREDENTIAL_EXPIRED", + "code": 401002, + "message": "The OIDC credentials for this file source have expired." + }, { "name": "USER_NO_API_KEY", "code": 403001, diff --git a/lib/galaxy/files/__init__.py b/lib/galaxy/files/__init__.py index cf21e294c036..042a9d5c0b6f 100644 --- a/lib/galaxy/files/__init__.py +++ b/lib/galaxy/files/__init__.py @@ -2,6 +2,7 @@ import os from collections import defaultdict from collections.abc import Callable +from datetime import datetime from typing import ( Any, NamedTuple, @@ -326,7 +327,7 @@ def to_dict( class FileSourceDictifiable(Dictifiable, DictifiableFilesSourceContext): - dict_collection_visible_keys = ("email", "username", "ftp_dir", "preferences", "is_admin") + dict_collection_visible_keys = ("email", "username", "ftp_dir", "preferences", "is_admin", "oidc_access_tokens") def to_dict(self, view="collection", value_mapper: Optional[dict[str, Callable]] = None) -> dict[str, Any]: rval = super().to_dict(view=view, value_mapper=value_mapper) @@ -361,6 +362,12 @@ def app_vault(self) -> dict[str, Any]: ... @property def anonymous(self) -> bool: ... + @property + def oidc_access_tokens(self) -> Optional[dict[str, str]]: ... + + @property + def oidc_access_token_expirations(self) -> dict[str, datetime]: ... + OptionalUserContext = Optional[FileSourcesUserContext] @@ -430,6 +437,35 @@ def file_sources(self): def anonymous(self) -> bool: return self.trans.anonymous + @property + def oidc_access_tokens(self) -> Optional[dict[str, str]]: + """ + Return all available access tokens for the current user. + """ + user = self.trans.user + if not user: + return None + tokens = {} + for authnz_token in user.social_auth: + extra_data = authnz_token.extra_data or {} + access_token = extra_data.get("access_token") + if access_token: + tokens[authnz_token.provider] = access_token + return tokens + + @property + def oidc_access_token_expirations(self) -> dict[str, datetime]: + from galaxy.tools.data_fetch_utils import compute_token_expiry_for_provider + + user = self.trans.user + if not user or not user.social_auth: + return {} + return { + auth.provider: expiry + for auth in user.social_auth + if (expiry := compute_token_expiry_for_provider(user, auth.provider)) is not None + } + class DictFileSourcesUserContext(FileSourcesUserContext, FileSourceDictifiable): def __init__(self, **kwd): @@ -478,3 +514,11 @@ def file_sources(self): @property def anonymous(self) -> bool: return not bool(self._kwd.get("username")) + + @property + def oidc_access_tokens(self) -> Optional[dict[str, str]]: + return self._kwd.get("oidc_access_tokens") + + @property + def oidc_access_token_expirations(self) -> dict[str, datetime]: + return self._kwd.get("oidc_access_token_expirations") or {} diff --git a/lib/galaxy/files/models.py b/lib/galaxy/files/models.py index ec2ec21fd324..bb839bcb38ee 100644 --- a/lib/galaxy/files/models.py +++ b/lib/galaxy/files/models.py @@ -207,6 +207,25 @@ class FilesSourceProperties(StrictModel): ), ), ] = None + oidc_auth_provider: Annotated[ + Optional[str], + Field( + None, + title="OIDC authorization provider", + description=("Specify an OIDC provider key to inject the access token as a Bearer Authorization header."), + ), + ] = None + auth_expires_at: Annotated[ + Optional[str], + Field( + title="Auth expires at", + description=( + "ISO-format UTC datetime at which the OIDC access token used by this source expires." + " Set at serialisation time for sources that resolve an Authorization header from" + " the user's OIDC credentials." + ), + ), + ] = None disable_templating: Annotated[ Optional[bool], Field( diff --git a/lib/galaxy/files/sources/__init__.py b/lib/galaxy/files/sources/__init__.py index 10e4318fc7a1..1e47033682bb 100644 --- a/lib/galaxy/files/sources/__init__.py +++ b/lib/galaxy/files/sources/__init__.py @@ -2,6 +2,10 @@ import builtins import os import time +from datetime import ( + datetime, + timezone, +) from enum import Enum from typing import ( Any, @@ -361,6 +365,44 @@ def _parse_common_props(self, config: FilesSourceProperties): self.requires_groups = config.requires_groups self.disable_templating = config.disable_templating self._validate_security_rules() + self._auth_expires_at: Optional[datetime] = ( + datetime.fromisoformat(config.auth_expires_at) if config.auth_expires_at else None + ) + + def _check_credentials_fresh(self) -> None: + if self._auth_expires_at and datetime.now(timezone.utc) > self._auth_expires_at: + from galaxy.exceptions import FileSourceCredentialExpired + + raise FileSourceCredentialExpired() + + def _compute_auth_expires_at(self, user_context: "OptionalUserContext") -> Optional[datetime]: + if user_context is None: + return None + provider = self.template_config.oidc_auth_provider + if not provider: + return None + expirations = getattr(user_context, "oidc_access_token_expirations", {}) + return expirations.get(provider) + + def _inject_oidc_bearer_token( + self, + http_headers: dict[str, str], + user_context: "OptionalUserContext", + ) -> Optional[dict[str, str]]: + """Return a copy of http_headers with a Bearer token added for the configured OIDC provider. + + Returns None if no provider is configured, no user context is available, or the user has + no token for that provider. Explicitly configured Authorization headers take precedence. + """ + provider = self.template_config.oidc_auth_provider + if not provider or not user_context: + return None + token = (getattr(user_context, "oidc_access_tokens", None) or {}).get(provider) + if not token: + return None + headers = dict(http_headers) + headers.setdefault("Authorization", f"Bearer {token}") + return headers def to_dict(self, for_serialization=False, user_context: "OptionalUserContext" = None) -> dict[str, Any]: rval: dict[str, Any] = { @@ -388,6 +430,13 @@ def to_dict(self, for_serialization=False, user_context: "OptionalUserContext" = context = self._get_runtime_context(user_context=user_context) serialized_config = self._serialize_config(context.config) rval.update(serialized_config) + if self.template_config.oidc_auth_provider is not None and user_context is not None: + updated_headers = self._inject_oidc_bearer_token(dict(rval.get("http_headers") or {}), user_context) + if updated_headers is not None: + rval["http_headers"] = updated_headers + expires_at = self._compute_auth_expires_at(user_context) + if expires_at is not None: + rval["auth_expires_at"] = expires_at.isoformat() return rval def _serialize_config(self, config: TResolvedConfig) -> dict[str, Any]: @@ -427,6 +476,10 @@ def _get_runtime_context( self.template_config = self.template_config.model_copy(update=extra_props) resolved_config = self._evaluate_template_config(user_data) + if self.template_config.oidc_auth_provider and user_context and hasattr(resolved_config, "http_headers"): + updated_headers = self._inject_oidc_bearer_token(dict(resolved_config.http_headers or {}), user_context) + if updated_headers is not None: + resolved_config = resolved_config.model_copy(update={"http_headers": updated_headers}) return FilesSourceRuntimeContext(user_data=user_data, config=resolved_config) def _apply_defaults_to_template( @@ -467,6 +520,7 @@ def list( sort_by: Optional[str] = None, ) -> tuple[list[AnyRemoteEntry], int]: self._check_user_access(user_context) + self._check_credentials_fresh() if not self.supports_pagination and (limit is not None or offset is not None): raise RequestParameterInvalidException("Pagination is not supported by this file source.") if not self.supports_search and query: @@ -524,6 +578,7 @@ def write_from( ) -> str: self._ensure_writeable() self._check_user_access(user_context) + self._check_credentials_fresh() resolved_config = self._get_runtime_context(opts, user_context) return self._write_from(target_path, native_path, resolved_config) or target_path @@ -544,6 +599,7 @@ def realize_to( opts: Optional[FilesSourceOptions] = None, ): self._check_user_access(user_context) + self._check_credentials_fresh() resolved_config = self._get_runtime_context(opts, user_context) self._realize_to(source_path, native_path, resolved_config) diff --git a/lib/galaxy/files/sources/drs.py b/lib/galaxy/files/sources/drs.py index eb3277da903a..487d6b8f96e4 100644 --- a/lib/galaxy/files/sources/drs.py +++ b/lib/galaxy/files/sources/drs.py @@ -58,12 +58,13 @@ def _realize_to( ): user_context = context.user_data.context if context.user_data.context else None config = context.config + headers = dict(config.http_headers) fetch_drs_to_file( source_path, native_path, user_context=user_context, fetch_url_allowlist=self._allowlist, - headers=config.http_headers, + headers=headers or None, force_http=config.force_http, ) diff --git a/lib/galaxy/jobs/__init__.py b/lib/galaxy/jobs/__init__.py index ab61666a234a..31b756e66904 100644 --- a/lib/galaxy/jobs/__init__.py +++ b/lib/galaxy/jobs/__init__.py @@ -1078,12 +1078,18 @@ def tool_directory(self): tool_dir = os.path.abspath(tool_dir) return tool_dir + def _refresh_oidc_tokens_for_job(self, trans: WorkRequestContext) -> None: + authnz_manager = getattr(self.app, "authnz_manager", None) + if authnz_manager and trans.user: + authnz_manager.refresh_expiring_oidc_tokens(trans, trans.user) + @property def job_io(self) -> JobIO: if self._job_io is None: job = self.get_job() work_request = WorkRequestContext(self.app, user=job.user, galaxy_session=job.galaxy_session) user_context = ProvidesFileSourcesUserContext(work_request) + self._refresh_oidc_tokens_for_job(work_request) tool_source = self.tool.tool_source.to_string() if self.tool else None tool_dir = self.tool.tool_dir if self.tool else None self._job_io = JobIO( diff --git a/lib/galaxy/tools/data_fetch.py b/lib/galaxy/tools/data_fetch.py index 2c6b48c7c484..0c975068dcda 100644 --- a/lib/galaxy/tools/data_fetch.py +++ b/lib/galaxy/tools/data_fetch.py @@ -49,7 +49,11 @@ def main(argv=None): args = _arg_parser().parse_args(argv) registry = Registry() registry.load_datatypes(root_dir=args.galaxy_root, config=args.datatypes_registry) - do_fetch(args.request, working_directory=args.working_directory or os.getcwd(), registry=registry) + do_fetch( + args.request, + working_directory=args.working_directory or os.getcwd(), + registry=registry, + ) def do_fetch( diff --git a/lib/galaxy/tools/data_fetch_utils.py b/lib/galaxy/tools/data_fetch_utils.py new file mode 100644 index 000000000000..2b63e20680a1 --- /dev/null +++ b/lib/galaxy/tools/data_fetch_utils.py @@ -0,0 +1,45 @@ +from datetime import ( + datetime, + timezone, +) +from typing import Any + +from galaxy.authnz.psa_authnz import locate_token_expiration +from galaxy.model import User + + +def iter_fetch_urls(value: Any): + if isinstance(value, dict): + if value.get("src") == "url" and "url" in value: + yield value["url"] + for child in value.values(): + yield from iter_fetch_urls(child) + elif isinstance(value, list): + for child in value: + yield from iter_fetch_urls(child) + + +def fetch_uses_authorization_header(request: dict[str, Any], file_sources, user_context) -> bool: + for url in iter_fetch_urls(request): + file_source_path = file_sources.get_file_source_path(url) + serialized = file_source_path.file_source.to_dict(for_serialization=True, user_context=user_context) + http_headers = serialized.get("http_headers") or {} + if http_headers.get("Authorization"): + return True + return False + + +def compute_token_expiry_for_provider(user: User | None, provider: str) -> datetime | None: + """Return the expiry for a specific OIDC provider's token, if available.""" + if user is None or not user.social_auth: + return None + for auth in user.social_auth: + if auth.provider != provider: + continue + extra_data = auth.extra_data or {} + auth_time = extra_data.get("auth_time") + expires = locate_token_expiration(extra_data) + if auth_time is None or expires is None: + return None + return datetime.fromtimestamp(int(auth_time) + int(expires), tz=timezone.utc) + return None diff --git a/test/unit/app/tools/test_data_fetch_utils.py b/test/unit/app/tools/test_data_fetch_utils.py new file mode 100644 index 000000000000..cc8a5a8b2dda --- /dev/null +++ b/test/unit/app/tools/test_data_fetch_utils.py @@ -0,0 +1,64 @@ +from datetime import ( + datetime, + timedelta, + timezone, +) +from typing import cast + +from galaxy.model import User +from galaxy.tools.data_fetch_utils import compute_token_expiry_for_provider + + +class DummyToken: + def __init__(self, provider, expiration_time): + self.provider = provider + now_ts = int(datetime.now(timezone.utc).timestamp()) + self.extra_data = { + "auth_time": now_ts, + "expires": int(expiration_time.timestamp()) - now_ts, + } + + +class DummyUser: + def __init__(self, social_auth): + self.social_auth = social_auth + + +def _truncate_to_seconds(value: datetime) -> datetime: + return value.replace(microsecond=0) + + +def test_compute_token_expiry_for_provider_returns_none_for_no_user(): + assert compute_token_expiry_for_provider(None, "oidc") is None + + +def test_compute_token_expiry_for_provider_returns_none_for_empty_social_auth(): + assert compute_token_expiry_for_provider(cast(User, DummyUser([])), "oidc") is None + + +def test_compute_token_expiry_for_provider_returns_expiry_for_matching_provider(): + expiry = datetime.now(timezone.utc) + timedelta(hours=1) + user = DummyUser([DummyToken("oidc", expiry)]) + assert compute_token_expiry_for_provider(cast(User, user), "oidc") == _truncate_to_seconds(expiry) + + +def test_compute_token_expiry_for_provider_ignores_other_providers(): + expiry = datetime.now(timezone.utc) + timedelta(hours=1) + user = DummyUser([DummyToken("google", expiry)]) + assert compute_token_expiry_for_provider(cast(User, user), "oidc") is None + + +def test_compute_token_expiry_for_provider_returns_correct_expiry_among_multiple_providers(): + oidc_expiry = datetime.now(timezone.utc) + timedelta(hours=2) + google_expiry = datetime.now(timezone.utc) + timedelta(minutes=5) + user = DummyUser([DummyToken("google", google_expiry), DummyToken("oidc", oidc_expiry)]) + result = compute_token_expiry_for_provider(cast(User, user), "oidc") + assert result == _truncate_to_seconds(oidc_expiry) + + +def test_compute_token_expiry_for_provider_returns_none_when_token_missing_auth_time_or_expires(): + token = DummyToken.__new__(DummyToken) + token.provider = "oidc" + token.extra_data = {} + user = DummyUser([token]) + assert compute_token_expiry_for_provider(cast(User, user), "oidc") is None diff --git a/test/unit/authnz/test_psa_authnz.py b/test/unit/authnz/test_psa_authnz.py index 6bd2429d0ebb..6f7de1126c0c 100644 --- a/test/unit/authnz/test_psa_authnz.py +++ b/test/unit/authnz/test_psa_authnz.py @@ -397,6 +397,25 @@ def test_oidc_config_custom_auth_pipeline_and_extra(mock_oidc_config_file, mock_ assert psa_authnz.config["SOCIAL_AUTH_PIPELINE"] == custom_auth_pipeline + tuple(custom_auth_pipeline_extra) +def make_psa_authnz(mock_oidc_config_file, mock_oidc_backend_config_file): + mock_app = MagicMock() + mock_app.config = SimpleNamespace( + oidc_auth_pipeline=None, + oidc_auth_pipeline_extra=None, + oidc=defaultdict(dict), + fixed_delegated_auth=False, + ) + manager = AuthnzManager( + app=mock_app, oidc_config_file=mock_oidc_config_file, oidc_backends_config_file=mock_oidc_backend_config_file + ) + return PSAAuthnz( + provider="oidc", + oidc_config=manager.oidc_config, + oidc_backend_config=manager.oidc_backends_config, + app_config=mock_app.config, + ) + + def test_sync_user_profile_skips_when_account_interface_enabled(): manager = MagicMock() session = MagicMock() diff --git a/test/unit/files/_util.py b/test/unit/files/_util.py index 6e939cafa572..d7103df5aae4 100644 --- a/test/unit/files/_util.py +++ b/test/unit/files/_util.py @@ -15,6 +15,7 @@ TEST_USERNAME = "alice" TEST_EMAIL = "alice@galaxyproject.org" +TEST_OIDC_ACCESS_TOKENS = {"oidc": "test-oidc-token"} def serialize_and_recover(file_sources_o: ConfiguredFileSources, user_context: OptionalUserContext = None): @@ -93,6 +94,7 @@ def user_context_fixture(user_ftp_dir=None, role_names=None, group_names=None, i group_names=group_names or set(), is_admin=is_admin, file_sources=file_sources, + oidc_access_tokens=TEST_OIDC_ACCESS_TOKENS, ) return user_context diff --git a/test/unit/files/drs_oidc_file_sources_conf.yml b/test/unit/files/drs_oidc_file_sources_conf.yml new file mode 100644 index 000000000000..387636f9d1f2 --- /dev/null +++ b/test/unit/files/drs_oidc_file_sources_conf.yml @@ -0,0 +1,5 @@ +- type: drs + id: test_oidc + doc: Test drs repository filesource with OIDC token attachment + http_headers: + Authorization: "Bearer ${user.oidc_access_tokens['oidc']}" diff --git a/test/unit/files/test_drs.py b/test/unit/files/test_drs.py index e00832f681d5..fe9676974df1 100644 --- a/test/unit/files/test_drs.py +++ b/test/unit/files/test_drs.py @@ -5,8 +5,13 @@ from typing import Any from unittest import mock +import pytest import responses +from galaxy.files import ( + DictFileSourcesUserContext, + ProvidesFileSourcesUserContext, +) from ._util import ( assert_realizes_as, assert_realizes_contains, @@ -16,6 +21,87 @@ SCRIPT_DIRECTORY = os.path.abspath(os.path.dirname(__file__)) FILE_SOURCES_CONF = os.path.join(SCRIPT_DIRECTORY, "drs_file_sources_conf.yml") +DRS_OIDC_FILE_SOURCES_CONF = os.path.join(SCRIPT_DIRECTORY, "drs_oidc_file_sources_conf.yml") + + +def test_provides_file_sources_user_context_oidc_access_tokens(): + """ProvidesFileSourcesUserContext.oidc_access_tokens reads all providers from social_auth.""" + + class DummyToken: + def __init__(self, provider, access_token): + self.provider = provider + self.extra_data = {"access_token": access_token} + + class DummyUser: + social_auth = [ + DummyToken("oidc", "oidc-token"), + DummyToken("keycloak", "keycloak-token"), + DummyToken("no_token_provider", None), # skipped — no access_token + ] + + class DummyTrans: + user = DummyUser() + + tokens = ProvidesFileSourcesUserContext(DummyTrans()).oidc_access_tokens + assert tokens == {"oidc": "oidc-token", "keycloak": "keycloak-token"} + + +def test_provides_file_sources_user_context_oidc_access_tokens_anonymous(): + """ProvidesFileSourcesUserContext.oidc_access_tokens returns None for anonymous users.""" + + class DummyTrans: + user = None + + assert ProvidesFileSourcesUserContext(DummyTrans()).oidc_access_tokens is None + + +def test_drs_http_headers_template_expansion(): + """Dict values in http_headers are expanded as templates during file source serialization.""" + oidc_token = "my-token" + file_sources = configured_file_sources(DRS_OIDC_FILE_SOURCES_CONF) + user_context = DictFileSourcesUserContext( + username="alice", + email="alice@galaxyproject.org", + preferences={}, + role_names=set(), + group_names=set(), + is_admin=False, + oidc_access_tokens={"oidc": oidc_token}, + ) + file_sources_dict = file_sources.to_dict(for_serialization=True, user_context=user_context) + drs_source = next(s for s in file_sources_dict["file_sources"] if s.get("type") == "drs") + assert drs_source["http_headers"]["Authorization"] == f"Bearer {oidc_token}" + + +def test_drs_oidc_token_wrong_provider_raises(): + """Referencing a provider the user doesn't have raises KeyError at serialization time.""" + file_sources = configured_file_sources(DRS_OIDC_FILE_SOURCES_CONF) + user_context = DictFileSourcesUserContext( + username="alice", + email="alice@galaxyproject.org", + preferences={}, + role_names=set(), + group_names=set(), + is_admin=False, + oidc_access_tokens={"keycloak": "kc-token"}, + ) + with pytest.raises(KeyError): + file_sources.to_dict(for_serialization=True, user_context=user_context) + + +def test_drs_oidc_token_no_tokens_raises(): + """A user with no OIDC tokens raises TypeError at serialization time.""" + file_sources = configured_file_sources(DRS_OIDC_FILE_SOURCES_CONF) + user_context = DictFileSourcesUserContext( + username="alice", + email="alice@galaxyproject.org", + preferences={}, + role_names=set(), + group_names=set(), + is_admin=False, + ) + with pytest.raises(TypeError): + file_sources.to_dict(for_serialization=True, user_context=user_context) @responses.activate @@ -126,3 +212,65 @@ def access_handler(request): assert_realizes_contains( file_sources, test_url, "PMID:30101859-Cao-2018-TGFBR2-Patient_4", user_context=user_context ) + + +@responses.activate +def test_file_source_drs_attach_oidc_token(): + """When http_headers is configured with a template referencing the user's OIDC token, it is sent as a Bearer header.""" + oidc_token = "MyOIDCAccessToken" + + def drs_repo_handler(request): + assert request.headers["Authorization"] == f"Bearer {oidc_token}" + data = { + "id": "999", + "name": "oidc-test-file", + "access_methods": [ + { + "type": "https", + "access_id": "abc", + } + ], + } + return (200, {}, json.dumps(data)) + + def access_handler(request): + assert request.headers["Authorization"] == f"Bearer {oidc_token}" + access_data = { + "url": "https://my.repository.org/oidcfile.txt", + "headers": [], + } + return (200, {}, json.dumps(access_data)) + + responses.add_callback( + responses.GET, + "https://drs.oidc-example.org/ga4gh/drs/v1/objects/999", + callback=drs_repo_handler, + content_type="application/json", + ) + responses.add_callback( + responses.GET, + "https://drs.oidc-example.org/ga4gh/drs/v1/objects/999/access/abc", + callback=access_handler, + content_type="application/json", + ) + + test_url = "drs://drs.oidc-example.org/999" + + def check_download(request, **kwargs): + response: Any = io.StringIO("hello oidc world") + response.headers = {} + response.geturl = lambda: test_url + return response + + with mock.patch.object(urllib.request, "urlopen", new=check_download): + user_context = DictFileSourcesUserContext( + username="alice", + email="alice@galaxyproject.org", + preferences={}, + role_names=set(), + group_names=set(), + is_admin=False, + oidc_access_tokens={"oidc": oidc_token}, + ) + file_sources = configured_file_sources(DRS_OIDC_FILE_SOURCES_CONF) + assert_realizes_as(file_sources, test_url, "hello oidc world", user_context=user_context) diff --git a/test/unit/webapps/galaxy/services/test_tools_service.py b/test/unit/webapps/galaxy/services/test_tools_service.py new file mode 100644 index 000000000000..4125b84601f7 --- /dev/null +++ b/test/unit/webapps/galaxy/services/test_tools_service.py @@ -0,0 +1,78 @@ +from typing import ( + Any, + cast, +) +from unittest.mock import Mock + +from galaxy.app_unittest_utils import galaxy_mock +from galaxy.files import ( + ConfiguredFileSources, + ConfiguredFileSourcesConf, +) +from galaxy.files.models import FileSourcePluginsConfig +from galaxy.managers.context import ProvidesHistoryContext +from galaxy.model import History +from galaxy.schema.fetch_data import FetchDataPayload +from galaxy.schema.fields import Security +from galaxy.webapps.galaxy.services.tools import ToolsService + + +class _ToolsServiceUnderTest(ToolsService): + def _create(self, trans, payload, **kwd): + return payload + + +class TestToolsService: + def setup_method(self): + self.trans = galaxy_mock.MockTrans() + self.app = self.trans.app + Security.security = self.app.security + self.app.config.check_upload_content = True + self.authnz_manager = Mock() + self.app.authnz_manager = self.authnz_manager + self.trans.init_user_in_database() + history = History(user=self.trans.user) + self.trans.sa_session.add(history) + self.trans.sa_session.commit() + self.trans.set_history(history) + + def test_create_fetch_does_not_refresh_when_fetch_has_no_authorization_header(self): + self.app.file_sources = ConfiguredFileSources( + FileSourcePluginsConfig(), + ConfiguredFileSourcesConf( + conf_dict=[ + { + "type": "http", + "id": "test_plain", + "url_regex": r"^https?://example\.org/", + } + ] + ), + ) + + service = _ToolsServiceUnderTest( + config=self.app.config, + toolbox_search=cast(Any, object()), + security=self.app.security, + history_manager=cast(Any, object()), + ) + payload = FetchDataPayload.model_validate( + { + "history_id": self.app.security.encode_id(self.trans.history.id), + "targets": [ + { + "destination": {"type": "hdas"}, + "elements": [ + { + "src": "url", + "url": "https://example.org/data.txt", + "ext": "txt", + } + ], + } + ], + } + ) + + service.create_fetch(cast(ProvidesHistoryContext, self.trans), payload) + cast(Mock, self.authnz_manager.refresh_expiring_oidc_tokens).assert_not_called()