Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def create(
pipeline_run_id=pipeline_run.id,
created_by=created_by,
pipeline_name=pipeline_name,
annotations=annotations,
)
session.commit()

Expand Down Expand Up @@ -338,10 +339,12 @@ def set_annotation(
raise errors.PermissionError(
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
)
pipeline_run_annotation = bts.PipelineRunAnnotation(
pipeline_run_id=id, key=key, value=value
_mirror_single_pipeline_run_annotation(
session=session,
pipeline_run_id=id,
key=key,
value=value,
)
session.merge(pipeline_run_annotation)
session.commit()

def delete_annotation(
Expand Down Expand Up @@ -1339,18 +1342,76 @@ def _truncate_for_annotation(
return value[:max_len]


def _mirror_single_pipeline_run_annotation(
*,
session: orm.Session,
pipeline_run_id: bts.IdType,
key: str,
value: str | None,
) -> None:
"""Write a single user annotation to the PipelineRunAnnotation table.

Applies defense-in-depth system-key guard, None-to-empty-string coercion,
and VARCHAR truncation before upserting the row.
"""
if key.startswith(filter_query_sql.SYSTEM_KEY_PREFIX):
_logger.warning(
f"Skipping annotation key {key!r} for pipeline run {pipeline_run_id}: "
f"keys starting with {filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved."
)
return

if value is None:
value = ""

value = _truncate_for_annotation(
value=value,
field_name=key,
pipeline_run_id=pipeline_run_id,
)
session.merge(
bts.PipelineRunAnnotation(
pipeline_run_id=pipeline_run_id,
key=key,
value=value,
)
)


def _mirror_pipeline_run_annotations(
*,
session: orm.Session,
pipeline_run_id: bts.IdType,
annotations: dict[str, Any] | None,
) -> None:
"""Mirror user-provided annotations into the PipelineRunAnnotation table."""
if not annotations:
return
for key, value in annotations.items():
str_value = str(value) if value is not None else None
_mirror_single_pipeline_run_annotation(
session=session,
pipeline_run_id=pipeline_run_id,
key=key,
value=str_value,
)


def _mirror_system_annotations(
*,
session: orm.Session,
pipeline_run_id: bts.IdType,
created_by: str | None,
pipeline_name: str | None,
annotations: dict[str, Any] | None = None,
) -> None:
"""Mirror pipeline run fields as system annotations for filter_query search.

Always creates an annotation for every run, even when the source value is
None or empty (stored as ""). This ensures data parity so every run has a
row for each system key.

Also mirrors user-provided annotations via _mirror_pipeline_run_annotations.
"""

# TODO: The original pipeline_run.created_by and the pipeline name stored in
Expand Down Expand Up @@ -1403,6 +1464,12 @@ def _mirror_system_annotations(
)
)

_mirror_pipeline_run_annotations(
session=session,
pipeline_run_id=pipeline_run_id,
annotations=annotations,
)


def _recursively_create_all_executions_and_artifacts_root(
session: orm.Session,
Expand Down
189 changes: 171 additions & 18 deletions tests/test_api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,122 @@ def test_create_mirrors_absent_values_as_empty_string(
)


class TestCreateMirrorsUserAnnotations:
def test_create_mirrors_user_annotations(
self,
session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
annotations = {"team": "ml-ops", "project": "search"}
run = _create_run(
session_factory,
service,
root_task=_make_task_spec("my-pipeline"),
annotations=annotations,
)
with session_factory() as session:
mirrored = service.list_annotations(session=session, id=run.id)
assert mirrored["team"] == "ml-ops"
assert mirrored["project"] == "search"

def test_create_mirrors_user_annotations_empty_dict(
self,
session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
run = _create_run(
session_factory,
service,
root_task=_make_task_spec("my-pipeline"),
annotations={},
)
with session_factory() as session:
mirrored = service.list_annotations(session=session, id=run.id)
assert set(mirrored.keys()) == {
filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY,
filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME,
}

def test_create_mirrors_user_annotations_none(
self,
session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
run = _create_run(
session_factory,
service,
root_task=_make_task_spec("my-pipeline"),
annotations=None,
)
with session_factory() as session:
mirrored = service.list_annotations(session=session, id=run.id)
assert set(mirrored.keys()) == {
filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY,
filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME,
}

def test_create_skips_system_prefix_in_user_annotations(
self,
session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
caplog: pytest.LogCaptureFixture,
) -> None:
annotations = {"system/foo": "bar", "valid": "ok"}
with caplog.at_level("WARNING"):
run = _create_run(
session_factory,
service,
root_task=_make_task_spec("my-pipeline"),
annotations=annotations,
)
with session_factory() as session:
mirrored = service.list_annotations(session=session, id=run.id)
assert "valid" in mirrored
assert mirrored["valid"] == "ok"
assert "system/foo" not in mirrored
assert any("system/foo" in r.message for r in caplog.records)

def test_create_mirrors_user_annotations_none_value_as_empty_string(
self,
session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
annotations = {"tag": None}
run = _create_run(
session_factory,
service,
root_task=_make_task_spec("my-pipeline"),
annotations=annotations,
)
with session_factory() as session:
mirrored = service.list_annotations(session=session, id=run.id)
assert mirrored["tag"] == ""

def test_create_user_annotations_coexist_with_system(
self,
session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
run = _create_run(
session_factory,
service,
root_task=_make_task_spec("my-pipeline"),
created_by="alice",
annotations={"team": "a"},
)
with session_factory() as session:
mirrored = service.list_annotations(session=session, id=run.id)
assert mirrored["team"] == "a"
assert (
mirrored[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY]
== "alice"
)
assert (
mirrored[filter_query_sql.PipelineRunAnnotationSystemKey.PIPELINE_NAME]
== "my-pipeline"
)


class TestPipelineRunAnnotationCrud:
def test_system_annotations_coexist_with_user_annotations(
self, session_factory, service
Expand Down Expand Up @@ -625,35 +741,30 @@ class TestAnnotationValueOverflow:
- create() via _mirror_system_annotations(): long pipeline_name, long created_by
"""

# TODO: set_annotation() currently has no truncation guard for the
# VARCHAR(255) limit on annotation key/value columns. These tests
# document the failure. Fix deferred to a separate PR to avoid
# convoluting the backfill + _mirror_system_annotations fix.

def test_set_annotation_long_value_raises_on_overflow(
def test_set_annotation_long_value_truncated(
self,
mysql_varchar_limit_session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
"""set_annotation() with a 300-char value overflows the
VARCHAR(255) column and triggers IntegrityError."""
"""set_annotation() with a 300-char value is truncated to 255
via _mirror_single_annotation()."""
run = _create_run(
mysql_varchar_limit_session_factory,
service,
root_task=_make_task_spec(),
created_by="user1",
)
with mysql_varchar_limit_session_factory() as session:
with pytest.raises(
sqlalchemy.exc.IntegrityError, match="Data too long.*value"
):
service.set_annotation(
session=session,
id=run.id,
key="team",
value="v" * 300,
user_name="user1",
)
service.set_annotation(
session=session,
id=run.id,
key="team",
value="v" * 300,
user_name="user1",
)
with mysql_varchar_limit_session_factory() as session:
annotations = service.list_annotations(session=session, id=run.id)
assert annotations["team"] == "v" * bts._STR_MAX_LENGTH

def test_set_annotation_long_key_raises_on_overflow(
self,
Expand Down Expand Up @@ -715,6 +826,48 @@ def test_create_run_long_created_by_truncated(
annotations = service.list_annotations(session=session, id=run.id)
assert annotations[key] == "u" * bts._STR_MAX_LENGTH

def test_create_truncates_long_user_annotation_value(
self,
mysql_varchar_limit_session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
"""create() with a 300-char user annotation value is truncated to 255
via _mirror_pipeline_run_annotations()."""
run = _create_run(
mysql_varchar_limit_session_factory,
service,
root_task=_make_task_spec(),
annotations={"long_val": "x" * 300},
)
with mysql_varchar_limit_session_factory() as session:
annotations = service.list_annotations(session=session, id=run.id)
assert annotations["long_val"] == "x" * bts._STR_MAX_LENGTH


class TestSetAnnotationBehavior:
def test_set_annotation_none_value_stored_as_empty_string(
self,
session_factory: orm.sessionmaker,
service: api_server_sql.PipelineRunsApiService_Sql,
) -> None:
run = _create_run(
session_factory,
service,
root_task=_make_task_spec(),
created_by="user1",
)
with session_factory() as session:
service.set_annotation(
session=session,
id=run.id,
key="tag",
value=None,
user_name="user1",
)
with session_factory() as session:
annotations = service.list_annotations(session=session, id=run.id)
assert annotations["tag"] == ""


class TestFilterQueryApiWiring:
def test_filter_query_validates_invalid_json(self, session_factory, service):
Expand Down
Loading