Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions sagemaker-train/src/sagemaker/train/evaluate/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from pydantic import BaseModel, Field
from sagemaker.core.common_utils import TagsDict
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.resources import Pipeline, PipelineExecution, Tag
from sagemaker.core.resources import Pipeline, PipelineExecution
from sagemaker.core.resources import Tag as ResourceTag # For Tag.get_all()
from sagemaker.core.shapes import Tag # For Pipeline.create() tags parameter
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature

Expand Down Expand Up @@ -68,9 +70,33 @@ def _create_evaluation_pipeline(
resolved_pipeline_definition = template.render(pipeline_name=pipeline_name)

# Create tags for the pipeline
tags.extend([
{"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"}
])
# Note: Tags must be Tag objects, not dicts, for Pydantic validation to pass
tag_objects = []

# Add evaluation tag
tag_objects.append(Tag(key=_TAG_SAGEMAKER_MODEL_EVALUATION, value="true"))

# Process any additional tags passed in
if tags:
for i, tag_item in enumerate(tags):
try:
if hasattr(tag_item, '__class__') and 'Tag' in tag_item.__class__.__name__:
# Already a Tag object
tag_objects.append(tag_item)
elif isinstance(tag_item, dict):
# Convert dict to Tag object - handle both lowercase and capitalized keys
key = tag_item.get("key") or tag_item.get("Key")
value = tag_item.get("value") or tag_item.get("Value")
if key and value:
tag_objects.append(Tag(key=str(key), value=str(value)))
else:
logger.warning(f"Skipping invalid tag at index {i}: {tag_item}")
else:
logger.warning(f"Skipping unsupported tag type at index {i}: {type(tag_item)}")
except Exception as e:
logger.warning(f"Error processing tag at index {i}: {e}")

logger.info(f"Creating pipeline with {len(tag_objects)} tags")

pipeline = Pipeline.create(
pipeline_name=pipeline_name,
Expand All @@ -79,7 +105,7 @@ def _create_evaluation_pipeline(
pipeline_definition=resolved_pipeline_definition,
pipeline_display_name=f"EvaluationPipeline-{eval_type.value}",
pipeline_description=f"Pipeline for {eval_type.value} evaluation jobs",
tags=tags,
tags=tag_objects,
session=session,
region=region
)
Expand Down Expand Up @@ -205,8 +231,8 @@ def _get_or_create_pipeline(
for pipeline in pipelines:
pipeline_arn = pipeline.pipeline_arn

# Get tags using Tag.get_all
tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region)
# Get tags using ResourceTag.get_all
tags_list = ResourceTag.get_all(resource_arn=pipeline_arn, session=session, region=region)
tags = {tag.key: tag.value for tag in tags_list}

# Validate tag
Expand Down Expand Up @@ -647,8 +673,8 @@ def get_all(
try:
pipeline_arn = pipeline.pipeline_arn

# Get tags using Tag.get_all
tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region)
# Get tags using ResourceTag.get_all
tags_list = ResourceTag.get_all(resource_arn=pipeline_arn, session=session, region=region)
tags = {tag.key: tag.value for tag in tags_list}

# Validate tag - only process evaluation pipelines
Expand Down