diff --git a/cloud_pipelines_backend/api_router.py b/cloud_pipelines_backend/api_router.py index 6652637..74e1bdb 100644 --- a/cloud_pipelines_backend/api_router.py +++ b/cloud_pipelines_backend/api_router.py @@ -390,6 +390,42 @@ def get_current_user( permissions=permissions, ) + ### Secrets routes + secrets_service = api_server_sql.SecretsApiService() + + router.get("/api/secrets/", tags=["secrets"], **default_config)( + inject_session_dependency( + inject_user_name(secrets_service.list_secrets, parameter_name="user_id") + ) + ) + router.post("/api/secrets/", tags=["secrets"], **default_config)( + add_parameter_annotation_metadata( + inject_session_dependency( + inject_user_name( + secrets_service.create_secret, parameter_name="user_id" + ) + ), + parameter_name="secret_value", + annotation_metadata=fastapi.Body(embed=True), + ) + ) + router.put("/api/secrets/{secret_name}", tags=["secrets"], **default_config)( + add_parameter_annotation_metadata( + inject_session_dependency( + inject_user_name( + secrets_service.update_secret, parameter_name="user_id" + ) + ), + parameter_name="secret_value", + annotation_metadata=fastapi.Body(embed=True), + ) + ) + router.delete("/api/secrets/{secret_name}", tags=["secrets"], **default_config)( + inject_session_dependency( + inject_user_name(secrets_service.delete_secret, parameter_name="user_id") + ) + ) + ### Component library routes component_service = components_api.ComponentService() diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index e8e0624..7c53b41 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -994,6 +994,129 @@ def get_signed_artifact_url( return GetArtifactSignedUrlResponse(signed_url=signed_url) +# === Secrets Service +@dataclasses.dataclass(kw_only=True) +class SecretInfoResponse: + secret_name: str + created_at: datetime.datetime | None + updated_at: datetime.datetime | None + + @classmethod + def 
from_db(cls, secret_row: bts.Secret) -> "SecretInfoResponse": + return SecretInfoResponse( + secret_name=secret_row.secret_name, + created_at=secret_row.created_at, + updated_at=secret_row.updated_at, + ) + + +@dataclasses.dataclass(kw_only=True) +class ListSecretsResponse: + secrets: list[SecretInfoResponse] + + +class SecretsApiService: + + def create_secret( + self, + *, + session: orm.Session, + user_id: str, + secret_name: str, + secret_value: str, + ) -> SecretInfoResponse: + secret_name = secret_name.strip() + if not secret_name: + raise ApiServiceError("Secret name must not be empty.") + return self._set_secret_value( + session=session, + user_id=user_id, + secret_name=secret_name, + secret_value=secret_value, + raise_if_exists=True, + ) + + def update_secret( + self, + *, + session: orm.Session, + user_id: str, + secret_name: str, + secret_value: str, + ) -> SecretInfoResponse: + return self._set_secret_value( + session=session, + user_id=user_id, + secret_name=secret_name, + secret_value=secret_value, + raise_if_not_exists=True, + ) + + def _set_secret_value( + self, + *, + session: orm.Session, + user_id: str, + secret_name: str, + secret_value: str, + raise_if_not_exists: bool = False, + raise_if_exists: bool = False, + ) -> SecretInfoResponse: + current_time = _get_current_time() + secret = session.get(bts.Secret, (user_id, secret_name)) + if secret: + if raise_if_exists: + raise errors.ItemAlreadyExistsError( + f"Secret with name '{secret_name}' already exists." + ) + secret.secret_value = secret_value + secret.updated_at = current_time + else: + if raise_if_not_exists: + raise errors.ItemNotFoundError( + f"Secret with name '{secret_name}' does not exist." 
+ ) + secret = bts.Secret( + user_id=user_id, + secret_name=secret_name, + secret_value=secret_value, + created_at=current_time, + updated_at=current_time, + ) + session.add(secret) + response = SecretInfoResponse.from_db(secret) + session.commit() + return response + + def delete_secret( + self, + *, + session: orm.Session, + user_id: str, + secret_name: str, + ) -> None: + secret = session.get(bts.Secret, (user_id, secret_name)) + if not secret: + raise errors.ItemNotFoundError( + f"Secret with name '{secret_name}' does not exist." + ) + session.delete(secret) + session.commit() + + def list_secrets( + self, + *, + session: orm.Session, + user_id: str, + ) -> ListSecretsResponse: + secrets = session.scalars( + sql.select(bts.Secret).where(bts.Secret.user_id == user_id) + ).all() + return ListSecretsResponse( + secrets=[SecretInfoResponse.from_db(secret) for secret in secrets] + ) + + # ============ # Idea for how to add deep nested graph: @@ -1005,11 +1128,16 @@ def get_signed_artifact_url( # No. Decided to first do topological sort and then 1-stage generation. +_ArtifactNodeOrDynamicDataType = typing.Union[ + bts.ArtifactNode, structures.DynamicDataArgument +] + + def _recursively_create_all_executions_and_artifacts_root( session: orm.Session, root_task_spec: structures.TaskSpec, ) -> bts.ExecutionNode: - input_artifact_nodes: dict[str, bts.ArtifactNode] = {} + input_artifact_nodes: dict[str, _ArtifactNodeOrDynamicDataType] = {} root_component_spec = root_task_spec.component_ref.spec if not root_component_spec: @@ -1035,12 +1163,8 @@ def _recursively_create_all_executions_and_artifacts_root( raise ApiServiceError( f"root task arguments can only be constants, but got {input_name}={input_argument}. {root_task_spec=}" ) - elif not isinstance(input_argument, str): - raise ApiServiceError( - f"root task constant argument must be a string, but got {input_name}={input_argument}. 
{root_task_spec=}" - ) # TODO: Support constant input artifacts (artifact IDs) - if input_argument is not None: + elif isinstance(input_argument, str): input_artifact_nodes[input_name] = ( # _construct_constant_artifact_node_and_add_to_session( # session=session, value=input_argument, artifact_type=input_spec.type @@ -1052,6 +1176,12 @@ def _recursively_create_all_executions_and_artifacts_root( # This constant artifact won't be added to the DB # TODO: Actually, they will be added... # We don't need to link this input artifact here. It will be handled downstream. + elif isinstance(input_argument, structures.DynamicDataArgument): + input_artifact_nodes[input_name] = input_argument + else: + raise ApiServiceError( + f"root task constant argument must be a string, but got {input_name}={input_argument}. {root_task_spec=}" + ) root_execution_node = _recursively_create_all_executions_and_artifacts( session=session, @@ -1065,7 +1195,7 @@ def _recursively_create_all_executions_and_artifacts_root( def _recursively_create_all_executions_and_artifacts( session: orm.Session, root_task_spec: structures.TaskSpec, - input_artifact_nodes: dict[str, bts.ArtifactNode], + input_artifact_nodes: dict[str, _ArtifactNodeOrDynamicDataType], ancestors: list[bts.ExecutionNode], ) -> bts.ExecutionNode: root_component_spec = root_task_spec.component_ref.spec @@ -1098,6 +1228,23 @@ def _recursively_create_all_executions_and_artifacts( input_artifact_nodes = dict(input_artifact_nodes) for input_spec in root_component_spec.inputs or []: input_artifact_node = input_artifact_nodes.get(input_spec.name) + if isinstance(input_artifact_node, structures.DynamicDataArgument): + # Stored on the execution node so the orchestrator can resolve these dynamic data arguments (e.g. secrets) at container launch time. 
+ extra_data = root_execution_node.extra_data or {} + dynamic_data_arguments = extra_data.setdefault( + bts.EXECUTION_NODE_EXTRA_DATA_DYNAMIC_DATA_ARGUMENTS_KEY, {} + ) + dynamic_data_arguments[input_spec.name] = input_artifact_node.dynamic_data + if not ( + isinstance(input_artifact_node.dynamic_data, str) + or len(input_artifact_node.dynamic_data) == 1 + ): + raise ApiServiceError( + f"Dynamic data argument must be a string or a dict with a single key set, but got {input_artifact_node.dynamic_data}" + ) + root_execution_node.extra_data = extra_data + # Not adding any artifact link for secret inputs + continue if input_artifact_node is None and not input_spec.optional: if input_spec.default: input_artifact_node = ( @@ -1163,7 +1310,8 @@ def _recursively_create_all_executions_and_artifacts( root_execution_node.container_execution_status = ( bts.ContainerExecutionStatus.QUEUED if all( - artifact_node.artifact_data + not isinstance(artifact_node, bts.ArtifactNode) + or artifact_node.artifact_data for artifact_node in input_artifact_nodes.values() ) else bts.ContainerExecutionStatus.WAITING_FOR_UPSTREAM @@ -1190,10 +1338,12 @@ def _recursively_create_all_executions_and_artifacts( raise ApiServiceError( f"child_task_spec.component_ref.spec is empty. 
{child_task_spec=}" ) - child_task_input_artifact_nodes: dict[str, bts.ArtifactNode] = {} + child_task_input_artifact_nodes: dict[ + str, _ArtifactNodeOrDynamicDataType + ] = {} for input_spec in child_component_spec.inputs or []: input_argument = (child_task_spec.arguments or {}).get(input_spec.name) - input_artifact_node: bts.ArtifactNode | None = None + input_artifact_node: _ArtifactNodeOrDynamicDataType | None = None if input_argument is None and not input_spec.optional: # Not failing on unconnected required input if there is a default value if input_spec.default is None: @@ -1233,6 +1383,9 @@ def _recursively_create_all_executions_and_artifacts( # artifact_type=input_spec.type, # ) # ) + elif isinstance(input_argument, structures.DynamicDataArgument): + # We'll deal with dynamic data (e.g. secrets) when launching the container. + input_artifact_node = input_argument else: raise ApiServiceError( f"Unexpected task argument: {input_spec.name}={input_argument}. {child_task_spec=}" diff --git a/cloud_pipelines_backend/backend_types_sql.py b/cloud_pipelines_backend/backend_types_sql.py index af16b3c..dca54c0 100644 --- a/cloud_pipelines_backend/backend_types_sql.py +++ b/cloud_pipelines_backend/backend_types_sql.py @@ -406,6 +406,7 @@ class ExecutionNode(_TableBase): EXECUTION_NODE_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = ( "orchestration_error_message" ) +EXECUTION_NODE_EXTRA_DATA_DYNAMIC_DATA_ARGUMENTS_KEY = "dynamic_data_arguments" CONTAINER_EXECUTION_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = ( "orchestration_error_message" ) @@ -476,3 +477,13 @@ class PipelineRunAnnotation(_TableBase): pipeline_run: orm.Mapped[PipelineRun] = orm.relationship(repr=False, init=False) key: orm.Mapped[str] = orm.mapped_column(default=None, primary_key=True) value: orm.Mapped[str | None] = orm.mapped_column(default=None) + + +class Secret(_TableBase): + __tablename__ = "secret" + user_id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True) + secret_name: 
orm.Mapped[str] = orm.mapped_column(primary_key=True) + secret_value: orm.Mapped[str] + created_at: orm.Mapped[datetime.datetime | None] = orm.mapped_column(default=None) + updated_at: orm.Mapped[datetime.datetime | None] = orm.mapped_column(default=None) + extra_data: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None) diff --git a/cloud_pipelines_backend/component_structures.py b/cloud_pipelines_backend/component_structures.py index 12a26e0..1bc4550 100644 --- a/cloud_pipelines_backend/component_structures.py +++ b/cloud_pipelines_backend/component_structures.py @@ -317,7 +317,26 @@ class TaskOutputArgument(_BaseModel): # Has additional constructor for convenie task_output: TaskOutputReference -ArgumentType = Union[PrimitiveTypes, GraphInputArgument, TaskOutputArgument] +DynamicDataReference = str | dict[str, Any] + + +@dataclasses.dataclass +class DynamicDataArgument(_BaseModel): + """Argument that references data that's dynamically produced by the execution system at runtime. 
+ + Examples of dynamic data: + * Secret value + * Container execution ID + * Pipeline run ID + * Loop index/item + """ + + dynamic_data: DynamicDataReference + + +ArgumentType = Union[ + PrimitiveTypes, GraphInputArgument, TaskOutputArgument, DynamicDataArgument +] @dataclasses.dataclass diff --git a/cloud_pipelines_backend/launchers/interfaces.py b/cloud_pipelines_backend/launchers/interfaces.py index 1272e60..f5807f9 100644 --- a/cloud_pipelines_backend/launchers/interfaces.py +++ b/cloud_pipelines_backend/launchers/interfaces.py @@ -36,6 +36,7 @@ class InputArgument: value: str | None = None uri: str | None = None staging_uri: str + is_secret: bool = False class ContainerTaskLauncher(typing.Generic[_TLaunchedContainer], abc.ABC): diff --git a/cloud_pipelines_backend/orchestrator_sql.py b/cloud_pipelines_backend/orchestrator_sql.py index e81f09a..37ebe18 100644 --- a/cloud_pipelines_backend/orchestrator_sql.py +++ b/cloud_pipelines_backend/orchestrator_sql.py @@ -26,6 +26,9 @@ _T = typing.TypeVar("_T") +DYNAMIC_DATA_SECRET_KEY = "secret" +DYNAMIC_DATA_SECRET_NAME_KEY = "name" + class OrchestratorError(RuntimeError): pass @@ -436,6 +439,39 @@ def generate_execution_log_uri( for output_spec in component_spec.outputs or [] } + # Handling secrets. + # We read secrets from execution_node.extra_data rather than from task_spec.arguments, + # because some secrets might have been passed from upstream graph inputs. 
+ dynamic_data_arguments: dict[str, dict[str, Any]] = ( + execution.extra_data or {} + ).get(bts.EXECUTION_NODE_EXTRA_DATA_DYNAMIC_DATA_ARGUMENTS_KEY, {}) + secret_hash = "" + for input_name, dynamic_data_argument in dynamic_data_arguments.items(): + if not isinstance(dynamic_data_argument, dict): + continue + dynamic_data_items = list(dynamic_data_argument.items()) + dynamic_data_key, dynamic_data_parameters = dynamic_data_items[0] + if dynamic_data_key == DYNAMIC_DATA_SECRET_KEY: + secret_parameters = dynamic_data_parameters + user_id = pipeline_run.created_by + secret_name = secret_parameters[DYNAMIC_DATA_SECRET_NAME_KEY] + secret = session.get(bts.Secret, (user_id, secret_name)) + if not secret: + raise OrchestratorError( + f"{execution.id=}: User error: Error resolving a secret argument for {input_name=}: User {user_id} does not have secret {secret_name}." + ) + secret_value = secret.secret_value + input_artifact_data[input_name] = bts.ArtifactData( + total_size=len(secret_value.encode("utf-8")), + is_dir=False, + value=secret_value, + uri=None, + # This hash is not used, so we're using a dummy value here that makes it possible to identify the secret arguments in the following code. 
+ hash=secret_hash, + ) + session.rollback() + + # Preparing the launcher input arguments input_arguments = { input_name: launcher_interfaces.InputArgument( total_size=artifact_data.total_size, @@ -447,6 +483,7 @@ def generate_execution_log_uri( execution_id=container_execution_uuid, input_name=input_name, ), + is_secret=(artifact_data.hash == secret_hash), ) for input_name, artifact_data in input_artifact_data.items() } diff --git a/tests/test_secrets.py b/tests/test_secrets.py new file mode 100644 index 0000000..6d4cb67 --- /dev/null +++ b/tests/test_secrets.py @@ -0,0 +1,132 @@ +from typing import Callable +from unittest import mock + +from sqlalchemy import orm + +from cloud_pipelines_backend import api_server_sql +from cloud_pipelines_backend import component_structures +from cloud_pipelines_backend import database_ops +from cloud_pipelines_backend.launchers import interfaces as launcher_interfaces + + +def _initialize_db_and_get_session_factory() -> Callable[[], orm.Session]: + db_engine = database_ops.create_db_engine_and_migrate_db(database_uri="sqlite://") + return lambda: orm.Session(bind=db_engine) + + +def test_running_pipeline_with_secrets(): + user = "user1" + secret_name = "SECRET_1" + secret_value = "SECRET_1_VALUE" + + secret_input_name = "secret_input" + + component_spec = component_structures.ComponentSpec( + inputs=[ + component_structures.InputSpec(name=secret_input_name), + ], + implementation=component_structures.ContainerImplementation( + container=component_structures.ContainerSpec(image="python") + ), + ) + + task_spec1 = component_structures.TaskSpec( + component_ref=component_structures.ComponentReference(spec=component_spec), + arguments={ + secret_input_name: component_structures.DynamicDataArgument( + dynamic_data={"secret": {"name": secret_name}} + ) + }, + ) + + graph_input_name = "graph_input_1" + task_spec2 = component_structures.TaskSpec( + component_ref=component_structures.ComponentReference(spec=component_spec), + arguments={ 
+ secret_input_name: component_structures.GraphInputArgument( + graph_input=component_structures.GraphInputReference( + input_name=graph_input_name + ) + ) + }, + ) + + for task_spec in [task_spec1, task_spec2]: + pipeline_spec = component_structures.ComponentSpec( + inputs=[ + component_structures.InputSpec(name=graph_input_name), + ], + implementation=component_structures.GraphImplementation( + graph=component_structures.GraphSpec( + tasks={ + "task": task_spec, + } + ) + ), + ) + + root_pipeline_task = component_structures.TaskSpec( + component_ref=component_structures.ComponentReference(spec=pipeline_spec), + arguments={ + graph_input_name: component_structures.DynamicDataArgument( + dynamic_data={"secret": {"name": secret_name}} + ) + }, + ) + + session_factory = _initialize_db_and_get_session_factory() + secrets_service = api_server_sql.SecretsApiService() + pipeline_runs_service = api_server_sql.PipelineRunsApiService_Sql() + + secrets_service.create_secret( + session=session_factory(), + user_id=user, + secret_name=secret_name, + secret_value=secret_value, + ) + + list_secrets_response = secrets_service.list_secrets( + session=session_factory(), + user_id=user, + ) + assert list_secrets_response.secrets + assert list_secrets_response.secrets[0].secret_name == secret_name + + pipeline_runs_service.create( + session=session_factory(), + root_task=root_pipeline_task, + created_by=user, + ) + + storage_provider_mock = mock.MagicMock() + launched_container_mock = mock.MagicMock( + status=launcher_interfaces.ContainerStatus.PENDING, + to_dict=lambda: {"foo": "bar"}, + ) + launch_container_task_mock = mock.MagicMock( + return_value=launched_container_mock + ) + launcher_mock = mock.MagicMock(launch_container_task=launch_container_task_mock) + data_root_uri = "file:///tmp/artifacts" + logs_root_uri = "file:///tmp/logs" + + from cloud_pipelines_backend import orchestrator_sql + + orchestrator = orchestrator_sql.OrchestratorService_Sql( + 
session_factory=session_factory, + launcher=launcher_mock, + storage_provider=storage_provider_mock, + data_root_uri=data_root_uri, + logs_root_uri=logs_root_uri, + ) + orchestrator.process_each_queue_once() + + launch_container_task_mock.assert_called_once() + input_arguments: dict[str, launcher_interfaces.InputArgument] | None = ( + launch_container_task_mock.call_args.kwargs.get("input_arguments") + ) + assert input_arguments + secret_argument = input_arguments.get(secret_input_name) + assert secret_argument + assert secret_argument.value == secret_value + assert secret_argument.is_secret