From 8afa91a5ef8804c67cc268c743d5936674ff8350 Mon Sep 17 00:00:00 2001
From: Ramprasath S
Date: Fri, 30 Jan 2026 08:05:25 -0500
Subject: [PATCH] fix: correct Tag class usage in pipeline creation

Pipeline creation was failing with Pydantic validation errors when
BenchmarkEvaluator attempted to create a new SageMaker Pipeline. This
occurred because the code imported Tag from sagemaker.core.resources
instead of sagemaker.core.shapes, which is what Pipeline.create()
expects for its tags parameter.

Root Cause:
The SDK has two different Tag classes:
- sagemaker.core.resources.Tag: Used for Tag.get_all() operations
- sagemaker.core.shapes.Tag: Used for Pipeline.create() parameter

Changes:
- Import Tag from sagemaker.core.shapes for Pipeline.create()
- Import Tag as ResourceTag from sagemaker.core.resources for Tag.get_all()
- Create proper Tag objects instead of dicts
- Add error handling for tag conversion
- Update Tag.get_all() calls to use ResourceTag

Impact:
This fixes benchmark evaluation failures (MMLU_PRO, BBH, GPQA, etc.)
when creating new pipelines.

Testing:
Verified both creating a new pipeline and reusing an existing pipeline.
--- .../src/sagemaker/train/evaluate/execution.py | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py index 3d217b08cf..e2388ef313 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py @@ -18,7 +18,9 @@ from pydantic import BaseModel, Field from sagemaker.core.common_utils import TagsDict from sagemaker.core.helper.session_helper import Session -from sagemaker.core.resources import Pipeline, PipelineExecution, Tag +from sagemaker.core.resources import Pipeline, PipelineExecution +from sagemaker.core.resources import Tag as ResourceTag # For Tag.get_all() +from sagemaker.core.shapes import Tag # For Pipeline.create() tags parameter from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -68,9 +70,33 @@ def _create_evaluation_pipeline( resolved_pipeline_definition = template.render(pipeline_name=pipeline_name) # Create tags for the pipeline - tags.extend([ - {"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"} - ]) + # Note: Tags must be Tag objects, not dicts, for Pydantic validation to pass + tag_objects = [] + + # Add evaluation tag + tag_objects.append(Tag(key=_TAG_SAGEMAKER_MODEL_EVALUATION, value="true")) + + # Process any additional tags passed in + if tags: + for i, tag_item in enumerate(tags): + try: + if hasattr(tag_item, '__class__') and 'Tag' in tag_item.__class__.__name__: + # Already a Tag object + tag_objects.append(tag_item) + elif isinstance(tag_item, dict): + # Convert dict to Tag object - handle both lowercase and capitalized keys + key = tag_item.get("key") or tag_item.get("Key") + value = tag_item.get("value") or tag_item.get("Value") + if key and value: + tag_objects.append(Tag(key=str(key), value=str(value))) + else: + 
logger.warning(f"Skipping invalid tag at index {i}: {tag_item}") + else: + logger.warning(f"Skipping unsupported tag type at index {i}: {type(tag_item)}") + except Exception as e: + logger.warning(f"Error processing tag at index {i}: {e}") + + logger.info(f"Creating pipeline with {len(tag_objects)} tags") pipeline = Pipeline.create( pipeline_name=pipeline_name, @@ -79,7 +105,7 @@ def _create_evaluation_pipeline( pipeline_definition=resolved_pipeline_definition, pipeline_display_name=f"EvaluationPipeline-{eval_type.value}", pipeline_description=f"Pipeline for {eval_type.value} evaluation jobs", - tags=tags, + tags=tag_objects, session=session, region=region ) @@ -205,8 +231,8 @@ def _get_or_create_pipeline( for pipeline in pipelines: pipeline_arn = pipeline.pipeline_arn - # Get tags using Tag.get_all - tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region) + # Get tags using ResourceTag.get_all + tags_list = ResourceTag.get_all(resource_arn=pipeline_arn, session=session, region=region) tags = {tag.key: tag.value for tag in tags_list} # Validate tag @@ -647,8 +673,8 @@ def get_all( try: pipeline_arn = pipeline.pipeline_arn - # Get tags using Tag.get_all - tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region) + # Get tags using ResourceTag.get_all + tags_list = ResourceTag.get_all(resource_arn=pipeline_arn, session=session, region=region) tags = {tag.key: tag.value for tag in tags_list} # Validate tag - only process evaluation pipelines