diff --git a/sagemaker-core/src/sagemaker/core/transformer.py b/sagemaker-core/src/sagemaker/core/transformer.py index 9e7d8b8127..93a35dc075 100644 --- a/sagemaker-core/src/sagemaker/core/transformer.py +++ b/sagemaker-core/src/sagemaker/core/transformer.py @@ -86,6 +86,7 @@ def __init__( base_transform_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + transform_ami_version: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``Transformer``. @@ -126,6 +127,15 @@ def __init__( AWS services needed. volume_kms_key (str or PipelineVariable): Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None). + transform_ami_version (str or PipelineVariable): Optional. Specifies an option + from a collection of preconfigured Amazon Machine Image (AMI) images. + Each image is configured by Amazon Web Services with a set of software + and driver versions. Valid values include: + + * 'al2-ami-sagemaker-batch-gpu-470' - GPU accelerator with NVIDIA driver 470 + * 'al2-ami-sagemaker-batch-gpu-535' - GPU accelerator with NVIDIA driver 535 + + (default: None). """ self.model_name = model_name self.strategy = strategy @@ -162,6 +172,7 @@ def __init__( TRANSFORM_JOB_ENVIRONMENT_PATH, sagemaker_session=self.sagemaker_session, ) + self.transform_ami_version = transform_ami_version @runnable_by_pipeline def transform( @@ -517,6 +528,9 @@ def _prepare_init_params_from_job_description(cls, job_details): init_params["volume_kms_key"] = getattr( job_details["transform_resources"], "volume_kms_key_id", None ) + init_params["transform_ami_version"] = getattr( + job_details["transform_resources"], "transform_ami_version", None + ) init_params["strategy"] = job_details.get("batch_strategy") if job_details.get("transform_output"): init_params["assemble_with"] = getattr( @@ -584,7 +598,10 @@ def _load_config(self, data, data_type, content_type, compression_type, split_ty ) resource_config = self._prepare_resource_config( - self.instance_count, self.instance_type, self.volume_kms_key + self.instance_count, + self.instance_type, + self.volume_kms_key, + self.transform_ami_version, ) return { @@ -631,13 +648,18 @@ def _prepare_output_config(self, s3_path, kms_key_id, assemble_with, accept): return config - def _prepare_resource_config(self, instance_count, instance_type, volume_kms_key): + def _prepare_resource_config( + self, instance_count, instance_type, volume_kms_key, transform_ami_version=None + ): """Prepare resource config.""" config = {"instance_count": instance_count, "instance_type": instance_type} if volume_kms_key is not None: config["volume_kms_key_id"] = volume_kms_key + if transform_ami_version is not None: + config["transform_ami_version"] = transform_ami_version + return config def _prepare_data_processing(self, input_filter, output_filter, join_source): diff --git a/sagemaker-core/tests/unit/test_transformer.py b/sagemaker-core/tests/unit/test_transformer.py index 1e7f068e54..621df013f1 100644 --- a/sagemaker-core/tests/unit/test_transformer.py +++ b/sagemaker-core/tests/unit/test_transformer.py @@ -67,6 +67,7 @@ def test_init_with_all_params(self, mock_session): base_transform_job_name="test-job", sagemaker_session=mock_session, volume_kms_key="volume-key", + transform_ami_version="al2-ami-sagemaker-batch-gpu-535", ) assert transformer.strategy == "MultiRecord" @@ -77,6 +78,7 @@ def test_init_with_all_params(self, mock_session): assert transformer.max_concurrent_transforms == 4 assert transformer.max_payload == 10 assert transformer.volume_kms_key == "volume-key" + assert transformer.transform_ami_version == "al2-ami-sagemaker-batch-gpu-535" def test_format_inputs_to_input_config(self, mock_session): """Test _format_inputs_to_input_config method""" @@ -179,6 +181,27 @@ def test_prepare_resource_config(self, mock_session): assert config["instance_type"] == "ml.m5.xlarge" assert config["volume_kms_key_id"] == "volume-key" + def test_prepare_resource_config_with_ami_version(self, mock_session): + """Test _prepare_resource_config with transform_ami_version""" + transformer = Transformer( + model_name="test-model", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + config = transformer._prepare_resource_config( + instance_count=2, + instance_type="ml.g4dn.xlarge", + volume_kms_key="volume-key", + transform_ami_version="al2-ami-sagemaker-batch-gpu-535", + ) + + assert config["instance_count"] == 2 + assert config["instance_type"] == "ml.g4dn.xlarge" + assert config["volume_kms_key_id"] == "volume-key" + assert config["transform_ami_version"] == "al2-ami-sagemaker-batch-gpu-535" + def test_prepare_resource_config_no_kms(self, mock_session): """Test _prepare_resource_config without KMS key""" transformer = Transformer( @@ -195,6 +218,7 @@ def test_prepare_resource_config_no_kms(self, mock_session): assert config["instance_count"] == 1 assert config["instance_type"] == "ml.m5.xlarge" assert "volume_kms_key_id" not in config + assert "transform_ami_version" not in config def test_prepare_data_processing_all_params(self, mock_session): """Test _prepare_data_processing with all parameters""" @@ -438,6 +462,87 @@ def test_prepare_init_params_from_job_description(self, mock_session): assert init_params["volume_kms_key"] == "volume-key" assert init_params["base_transform_job_name"] == "test-job-456" + def test_prepare_init_params_from_job_description_with_ami_version(self, mock_session): + """Test _prepare_init_params_from_job_description with transform_ami_version""" + job_details = { + "model_name": "test-model", + "transform_resources": Mock( + instance_count=2, + instance_type="ml.g4dn.xlarge", + volume_kms_key_id="volume-key", + transform_ami_version="al2-ami-sagemaker-batch-gpu-535", + ), + "batch_strategy": "SingleRecord", + "transform_output": Mock( + assemble_with="None", + s3_output_path="s3://bucket/output", + kms_key_id="output-key", + accept="text/csv", + ), + "max_concurrent_transforms": 8, + "max_payload_in_mb": 20, + "transform_job_name": "test-job-789", + } + + init_params = Transformer._prepare_init_params_from_job_description(job_details) + + assert init_params["model_name"] == "test-model" + assert init_params["instance_count"] == 2 + assert init_params["instance_type"] == "ml.g4dn.xlarge" + assert init_params["volume_kms_key"] == "volume-key" + assert init_params["transform_ami_version"] == "al2-ami-sagemaker-batch-gpu-535" + assert init_params["base_transform_job_name"] == "test-job-789" + + def test_init_with_transform_ami_version(self, mock_session): + """Test initialization with transform_ami_version parameter""" + transformer = Transformer( + model_name="test-model", + instance_count=1, + instance_type="ml.g4dn.xlarge", + transform_ami_version="al2-ami-sagemaker-batch-gpu-535", + sagemaker_session=mock_session, + ) + + assert transformer.model_name == "test-model" + assert transformer.instance_count == 1 + assert transformer.instance_type == "ml.g4dn.xlarge" + assert transformer.transform_ami_version == "al2-ami-sagemaker-batch-gpu-535" + + def test_init_without_transform_ami_version(self, mock_session): + """Test initialization without transform_ami_version parameter""" + transformer = Transformer( + model_name="test-model", + instance_count=1, + instance_type="ml.g4dn.xlarge", + sagemaker_session=mock_session, + ) + + assert transformer.transform_ami_version is None + + def test_load_config_with_transform_ami_version(self, mock_session): + """Test _load_config includes transform_ami_version in resource_config""" + transformer = Transformer( + model_name="test-model", + instance_count=2, + instance_type="ml.g4dn.xlarge", + output_path="s3://bucket/output", + transform_ami_version="al2-ami-sagemaker-batch-gpu-535", + sagemaker_session=mock_session, + ) + + config = transformer._load_config( + data="s3://bucket/input", + data_type="S3Prefix", + content_type="text/csv", + compression_type=None, + split_type="Line", + ) + + assert "resource_config" in config + assert config["resource_config"]["instance_count"] == 2 + assert config["resource_config"]["instance_type"] == "ml.g4dn.xlarge" + assert config["resource_config"]["transform_ami_version"] == "al2-ami-sagemaker-batch-gpu-535" + def test_delete_model(self, mock_session): """Test delete_model method""" transformer = Transformer(