Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions sagemaker-core/src/sagemaker/core/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
base_transform_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
transform_ami_version: Optional[Union[str, PipelineVariable]] = None,
):
"""Initialize a ``Transformer``.

Expand Down Expand Up @@ -126,6 +127,15 @@ def __init__(
AWS services needed.
volume_kms_key (str or PipelineVariable): Optional. KMS key ID for encrypting
the volume attached to the ML compute instance (default: None).
transform_ami_version (str or PipelineVariable): Optional. Specifies an option
from a collection of preconfigured Amazon Machine Image (AMI) images.
Each image is configured by Amazon Web Services with a set of software
and driver versions. Valid values include:

* 'al2-ami-sagemaker-batch-gpu-470' - GPU accelerator with NVIDIA driver 470
* 'al2-ami-sagemaker-batch-gpu-535' - GPU accelerator with NVIDIA driver 535

(default: None).
"""
self.model_name = model_name
self.strategy = strategy
Expand Down Expand Up @@ -162,6 +172,7 @@ def __init__(
TRANSFORM_JOB_ENVIRONMENT_PATH,
sagemaker_session=self.sagemaker_session,
)
self.transform_ami_version = transform_ami_version

@runnable_by_pipeline
def transform(
Expand Down Expand Up @@ -517,6 +528,9 @@ def _prepare_init_params_from_job_description(cls, job_details):
init_params["volume_kms_key"] = getattr(
job_details["transform_resources"], "volume_kms_key_id", None
)
init_params["transform_ami_version"] = getattr(
job_details["transform_resources"], "transform_ami_version", None
)
init_params["strategy"] = job_details.get("batch_strategy")
if job_details.get("transform_output"):
init_params["assemble_with"] = getattr(
Expand Down Expand Up @@ -584,7 +598,10 @@ def _load_config(self, data, data_type, content_type, compression_type, split_ty
)

resource_config = self._prepare_resource_config(
self.instance_count, self.instance_type, self.volume_kms_key
self.instance_count,
self.instance_type,
self.volume_kms_key,
self.transform_ami_version,
)

return {
Expand Down Expand Up @@ -631,13 +648,18 @@ def _prepare_output_config(self, s3_path, kms_key_id, assemble_with, accept):

return config

def _prepare_resource_config(self, instance_count, instance_type, volume_kms_key):
def _prepare_resource_config(
self, instance_count, instance_type, volume_kms_key, transform_ami_version=None
):
"""Prepare resource config."""
config = {"instance_count": instance_count, "instance_type": instance_type}

if volume_kms_key is not None:
config["volume_kms_key_id"] = volume_kms_key

if transform_ami_version is not None:
config["transform_ami_version"] = transform_ami_version

return config

def _prepare_data_processing(self, input_filter, output_filter, join_source):
Expand Down
105 changes: 105 additions & 0 deletions sagemaker-core/tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_init_with_all_params(self, mock_session):
base_transform_job_name="test-job",
sagemaker_session=mock_session,
volume_kms_key="volume-key",
transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
)

assert transformer.strategy == "MultiRecord"
Expand All @@ -77,6 +78,7 @@ def test_init_with_all_params(self, mock_session):
assert transformer.max_concurrent_transforms == 4
assert transformer.max_payload == 10
assert transformer.volume_kms_key == "volume-key"
assert transformer.transform_ami_version == "al2-ami-sagemaker-batch-gpu-535"

def test_format_inputs_to_input_config(self, mock_session):
"""Test _format_inputs_to_input_config method"""
Expand Down Expand Up @@ -179,6 +181,27 @@ def test_prepare_resource_config(self, mock_session):
assert config["instance_type"] == "ml.m5.xlarge"
assert config["volume_kms_key_id"] == "volume-key"

def test_prepare_resource_config_with_ami_version(self, mock_session):
    """_prepare_resource_config should pass transform_ami_version through to the config."""
    transformer = Transformer(
        model_name="test-model",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        sagemaker_session=mock_session,
    )

    config = transformer._prepare_resource_config(
        instance_count=2,
        instance_type="ml.g4dn.xlarge",
        volume_kms_key="volume-key",
        transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
    )

    expected = {
        "instance_count": 2,
        "instance_type": "ml.g4dn.xlarge",
        "volume_kms_key_id": "volume-key",
        "transform_ami_version": "al2-ami-sagemaker-batch-gpu-535",
    }
    for key, value in expected.items():
        assert config[key] == value

def test_prepare_resource_config_no_kms(self, mock_session):
"""Test _prepare_resource_config without KMS key"""
transformer = Transformer(
Expand All @@ -195,6 +218,7 @@ def test_prepare_resource_config_no_kms(self, mock_session):
assert config["instance_count"] == 1
assert config["instance_type"] == "ml.m5.xlarge"
assert "volume_kms_key_id" not in config
assert "transform_ami_version" not in config

def test_prepare_data_processing_all_params(self, mock_session):
"""Test _prepare_data_processing with all parameters"""
Expand Down Expand Up @@ -438,6 +462,87 @@ def test_prepare_init_params_from_job_description(self, mock_session):
assert init_params["volume_kms_key"] == "volume-key"
assert init_params["base_transform_job_name"] == "test-job-456"

def test_prepare_init_params_from_job_description_with_ami_version(self, mock_session):
    """A job description carrying transform_ami_version should round-trip into init params."""
    resources = Mock(
        instance_count=2,
        instance_type="ml.g4dn.xlarge",
        volume_kms_key_id="volume-key",
        transform_ami_version="al2-ami-sagemaker-batch-gpu-535",
    )
    output = Mock(
        assemble_with="None",
        s3_output_path="s3://bucket/output",
        kms_key_id="output-key",
        accept="text/csv",
    )
    job_details = {
        "model_name": "test-model",
        "transform_resources": resources,
        "batch_strategy": "SingleRecord",
        "transform_output": output,
        "max_concurrent_transforms": 8,
        "max_payload_in_mb": 20,
        "transform_job_name": "test-job-789",
    }

    init_params = Transformer._prepare_init_params_from_job_description(job_details)

    expected = {
        "model_name": "test-model",
        "instance_count": 2,
        "instance_type": "ml.g4dn.xlarge",
        "volume_kms_key": "volume-key",
        "transform_ami_version": "al2-ami-sagemaker-batch-gpu-535",
        "base_transform_job_name": "test-job-789",
    }
    for key, value in expected.items():
        assert init_params[key] == value

def test_init_with_transform_ami_version(self, mock_session):
    """The constructor should store an explicitly supplied transform_ami_version."""
    init_kwargs = {
        "model_name": "test-model",
        "instance_count": 1,
        "instance_type": "ml.g4dn.xlarge",
        "transform_ami_version": "al2-ami-sagemaker-batch-gpu-535",
    }
    transformer = Transformer(sagemaker_session=mock_session, **init_kwargs)

    # Every constructor argument should surface unchanged as an attribute.
    for attr, value in init_kwargs.items():
        assert getattr(transformer, attr) == value

def test_init_without_transform_ami_version(self, mock_session):
    """transform_ami_version should default to None when not supplied."""
    init_kwargs = dict(
        model_name="test-model",
        instance_count=1,
        instance_type="ml.g4dn.xlarge",
        sagemaker_session=mock_session,
    )
    transformer = Transformer(**init_kwargs)

    assert transformer.transform_ami_version is None

def test_load_config_with_transform_ami_version(self, mock_session):
    """_load_config should propagate transform_ami_version into resource_config."""
    ami_version = "al2-ami-sagemaker-batch-gpu-535"
    transformer = Transformer(
        model_name="test-model",
        instance_count=2,
        instance_type="ml.g4dn.xlarge",
        output_path="s3://bucket/output",
        transform_ami_version=ami_version,
        sagemaker_session=mock_session,
    )

    config = transformer._load_config(
        data="s3://bucket/input",
        data_type="S3Prefix",
        content_type="text/csv",
        compression_type=None,
        split_type="Line",
    )

    assert "resource_config" in config
    resource_config = config["resource_config"]
    assert resource_config["instance_count"] == 2
    assert resource_config["instance_type"] == "ml.g4dn.xlarge"
    assert resource_config["transform_ami_version"] == ami_version

def test_delete_model(self, mock_session):
"""Test delete_model method"""
transformer = Transformer(
Expand Down