Skip to content

Commit b9618d8

Browse files
authored
[AQUA] Block legacy ft model (#1238)
2 parents d852209 + dcfa11e commit b9618d8

File tree

9 files changed

+176
-35
lines changed

9 files changed

+176
-35
lines changed

ads/aqua/common/utils.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -643,14 +643,18 @@ def get_resource_name(ocid: str) -> str:
643643
return name
644644

645645

646-
def get_model_by_reference_paths(model_file_description: dict):
646+
def get_model_by_reference_paths(
647+
model_file_description: dict, is_ft_model_v2: bool = False
648+
):
647649
"""Reads the model file description json dict and returns the base model path and fine-tuned path for
648650
models created by reference.
649651
650652
Parameters
651653
----------
652654
model_file_description: dict
653655
json dict containing model paths and objects for models created by reference.
656+
is_ft_model_v2: bool
657+
Flag to indicate if it's fine tuned model v2. Defaults to False.
654658
655659
Returns
656660
-------
@@ -666,8 +670,18 @@ def get_model_by_reference_paths(model_file_description: dict):
666670
"Please check if the model created by reference has the correct artifact."
667671
)
668672

673+
if is_ft_model_v2:
674+
# model_file_description json for fine tuned model v2 contains only fine tuned model artifacts
675+
# so first model is always the fine tuned model
676+
ft_model_artifact = models[0]
677+
fine_tune_output_path = f"oci://{ft_model_artifact['bucketName']}@{ft_model_artifact['namespace']}/{ft_model_artifact['prefix']}".rstrip(
678+
"/"
679+
)
680+
681+
return UNKNOWN, fine_tune_output_path
682+
669683
if len(models) > 0:
670-
# since the model_file_description json does not have a flag to identify the base model, we consider
684+
# since the model_file_description json for legacy fine tuned model does not have a flag to identify the base model, we consider
671685
# the first instance to be the base model.
672686
base_model_artifact = models[0]
673687
base_model_path = f"oci://{base_model_artifact['bucketName']}@{base_model_artifact['namespace']}/{base_model_artifact['prefix']}".rstrip(

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
MODEL_FILE_DESCRIPTION_VERSION = "1.0"
4747
MODEL_FILE_DESCRIPTION_TYPE = "modelOSSReferenceDescription"
4848
AQUA_FINE_TUNE_MODEL_VERSION = "v2"
49+
INCLUDE_BASE_MODEL = 1
4950

5051
TRAINING_METRICS_FINAL = "training_metrics_final"
5152
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/finetuning/finetuning.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
upload_local_to_os,
2626
)
2727
from ads.aqua.constants import (
28+
AQUA_FINE_TUNE_MODEL_VERSION,
2829
DEFAULT_FT_BATCH_SIZE,
2930
DEFAULT_FT_BLOCK_STORAGE_SIZE,
3031
DEFAULT_FT_REPLICA,
@@ -306,7 +307,9 @@ def create(
306307
}
307308
# needs to add 'fine_tune_model_version' tag when creating the ft model for the
308309
# ft container to block merging base model artifact with ft model artifact.
309-
ft_model_freeform_tags = {Tags.AQUA_FINE_TUNE_MODEL_VERSION: "v2"}
310+
ft_model_freeform_tags = {
311+
Tags.AQUA_FINE_TUNE_MODEL_VERSION: AQUA_FINE_TUNE_MODEL_VERSION
312+
}
310313

311314
ft_model = self.create_model_catalog(
312315
display_name=create_fine_tuning_details.ft_name,

ads/aqua/model/model.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -172,8 +172,10 @@ def create(
172172
The instance of DataScienceModel or DataScienceModelGroup.
173173
"""
174174
fine_tune_weights = []
175+
model_name = ""
175176
if isinstance(model, AquaMultiModelRef):
176177
fine_tune_weights = model.fine_tune_weights
178+
model_name = model.model_name
177179
model = model.model_id
178180

179181
service_model = DataScienceModel.from_id(model)
@@ -194,6 +196,7 @@ def create(
194196
if fine_tune_weights:
195197
custom_model = self._create_model_group(
196198
model_id=model,
199+
model_name=model_name,
197200
compartment_id=target_compartment,
198201
project_id=target_project,
199202
freeform_tags=combined_freeform_tags,
@@ -268,6 +271,7 @@ def _create_model(
268271
def _create_model_group(
269272
self,
270273
model_id: str,
274+
model_name: str,
271275
compartment_id: str,
272276
project_id: str,
273277
freeform_tags: Dict,
@@ -276,6 +280,20 @@ def _create_model_group(
276280
service_model: DataScienceModel,
277281
):
278282
"""Creates a data science model group."""
283+
member_models = [
284+
{
285+
"inference_key": fine_tune_weight.model_name,
286+
"model_id": fine_tune_weight.model_id,
287+
}
288+
for fine_tune_weight in fine_tune_weights
289+
]
290+
# must also include base model info in member models to create stacked model group
291+
member_models.append(
292+
{
293+
"inference_key": model_name or service_model.display_name,
294+
"model_id": model_id,
295+
}
296+
)
279297
custom_model = (
280298
DataScienceModelGroup()
281299
.with_compartment_id(compartment_id)
@@ -286,15 +304,7 @@ def _create_model_group(
286304
.with_defined_tags(**defined_tags)
287305
.with_custom_metadata_list(service_model.custom_metadata_list)
288306
.with_base_model_id(model_id)
289-
.with_member_models(
290-
[
291-
{
292-
"inference_key": fine_tune_weight.model_name,
293-
"model_id": fine_tune_weight.model_id,
294-
}
295-
for fine_tune_weight in fine_tune_weights
296-
]
297-
)
307+
.with_member_models(member_models)
298308
.create()
299309
)
300310

ads/aqua/model/utils.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,10 @@
55

66
from typing import Tuple
77

8+
from ads.aqua.common.enums import Tags
89
from ads.aqua.common.errors import AquaValueError
910
from ads.aqua.common.utils import get_model_by_reference_paths
11+
from ads.aqua.constants import AQUA_FINE_TUNE_MODEL_VERSION
1012
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
1113
from ads.common.object_storage_details import ObjectStorageDetails
1214
from ads.model.datascience_model import DataScienceModel
@@ -34,8 +36,12 @@ def extract_base_model_from_ft(aqua_model: DataScienceModel) -> Tuple[str, str]:
3436
def extract_fine_tune_artifacts_path(aqua_model: DataScienceModel) -> Tuple[str, str]:
3537
"""Extracts the fine tuning source (fine_tune_output_path) and base model path from the DataScienceModel Object"""
3638

39+
is_ft_model_v2 = (
40+
aqua_model.freeform_tags.get(Tags.AQUA_FINE_TUNE_MODEL_VERSION, "").lower()
41+
== AQUA_FINE_TUNE_MODEL_VERSION
42+
)
3743
base_model_path, fine_tune_output_path = get_model_by_reference_paths(
38-
aqua_model.model_file_description
44+
aqua_model.model_file_description, is_ft_model_v2
3945
)
4046

4147
if not fine_tune_output_path or not ObjectStorageDetails.is_oci_path(

ads/aqua/modeldeployment/deployment.py

Lines changed: 36 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,13 @@ def create(
228228
raise AquaValueError(
229229
"Invalid 'models' provided. Only one base model is required for model stack deployment."
230230
)
231+
self._validate_input_models(create_deployment_details)
231232
model = create_deployment_details.models[0]
233+
else:
234+
try:
235+
create_deployment_details.validate_ft_model_v2(model_id=model)
236+
except ConfigValidationError as err:
237+
raise AquaValueError(f"{err}") from err
232238

233239
service_model_id = model if isinstance(model, str) else model.model_id
234240
logger.debug(
@@ -258,26 +264,9 @@ def create(
258264
)
259265
# TODO: add multi model validation from deployment_type
260266
else:
261-
# Collect all unique model IDs (including fine-tuned models)
262-
source_model_ids = list(
263-
{
264-
model_id
265-
for model in create_deployment_details.models
266-
for model_id in model.all_model_ids()
267-
}
268-
)
269-
logger.debug(
270-
"Fetching source model metadata for model IDs: %s", source_model_ids
267+
source_models, source_model_ids = self._validate_input_models(
268+
create_deployment_details
271269
)
272-
# Fetch source model metadata
273-
source_models = self.get_multi_source(source_model_ids) or {}
274-
275-
try:
276-
create_deployment_details.validate_input_models(
277-
model_details=source_models
278-
)
279-
except ConfigValidationError as err:
280-
raise AquaValueError(f"{err}") from err
281270

282271
base_model_ids = [
283272
model.model_id for model in create_deployment_details.models
@@ -394,6 +383,32 @@ def create(
394383
container_config=container_config,
395384
)
396385

386+
def _validate_input_models(
387+
self,
388+
create_deployment_details: CreateModelDeploymentDetails,
389+
):
390+
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
391+
# Collect all unique model IDs (including fine-tuned models)
392+
source_model_ids = list(
393+
{
394+
model_id
395+
for model in create_deployment_details.models
396+
for model_id in model.all_model_ids()
397+
}
398+
)
399+
logger.debug(
400+
"Fetching source model metadata for model IDs: %s", source_model_ids
401+
)
402+
# Fetch source model metadata
403+
source_models = self.get_multi_source(source_model_ids) or {}
404+
405+
try:
406+
create_deployment_details.validate_input_models(model_details=source_models)
407+
except ConfigValidationError as err:
408+
raise AquaValueError(f"{err}") from err
409+
410+
return source_models, source_model_ids
411+
397412
def _build_model_group_configs(
398413
self,
399414
models: List[AquaMultiModelRef],
@@ -909,6 +924,8 @@ def _create(
909924
params_dict = get_params_dict(params)
910925
# updates `--served-model-name` with service model id
911926
params_dict.update({"--served-model-name": aqua_model.base_model_id})
927+
# TODO: sets `--max-lora-rank` as 32 in params for now, will revisit later
928+
params_dict.update({"--max-lora-rank": 32})
912929
# adds `--enable_lora` to parameters
913930
params_dict.update({"--enable_lora": UNKNOWN})
914931
params = build_params_string(params_dict)

ads/aqua/modeldeployment/entities.py

Lines changed: 65 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,11 @@
1212
from ads.aqua.common.enums import Tags
1313
from ads.aqua.common.errors import AquaValueError
1414
from ads.aqua.config.utils.serializer import Serializable
15-
from ads.aqua.constants import UNKNOWN_DICT
15+
from ads.aqua.constants import (
16+
AQUA_FINE_TUNE_MODEL_VERSION,
17+
INCLUDE_BASE_MODEL,
18+
UNKNOWN_DICT,
19+
)
1620
from ads.aqua.data import AquaResourceIdentifier
1721
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
1822
from ads.aqua.modeldeployment.config_loader import (
@@ -509,11 +513,13 @@ def validate_multimodel_deployment_feasibility(
509513

510514
def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> None:
511515
"""
512-
Validates the input models for a multi-model deployment configuration.
516+
Validates the input models for a stacked-model or multi-model deployment configuration.
513517
514518
Validation Criteria:
515519
- The base model must be explicitly provided.
516520
- The base model must be in 'ACTIVE' state.
521+
- Fine-tuned models must have a tag 'fine_tune_model_version' as v2 to be supported.
522+
- Fine-tuned models must not have custom metadata 'include_base_model_artifact' as 1.
517523
- Fine-tuned model IDs must refer to valid, tagged fine-tuned models.
518524
- Fine-tuned models must refer back to the same base model.
519525
- All model names (including fine-tuned variants) must be unique.
@@ -609,6 +615,8 @@ def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> N
609615
f"Invalid fine-tuned model ID '{ft_model_id}': missing tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}'."
610616
)
611617

618+
self.validate_ft_model_v2(model=ft_model)
619+
612620
ft_base_model_id = ft_model.custom_metadata_list.get(
613621
FineTuneCustomMetadata.FINE_TUNE_SOURCE,
614622
ModelCustomMetadataItem(
@@ -650,6 +658,61 @@ def validate_input_models(self, model_details: Dict[str, DataScienceModel]) -> N
650658
f"{', '.join(sorted(duplicate_names))}. Model names must be unique for proper routing in multi-model deployments."
651659
)
652660

661+
def validate_ft_model_v2(
662+
self, model_id: Optional[str] = None, model: Optional[DataScienceModel] = None
663+
) -> None:
664+
"""
665+
Validates the input fine tuned model for model deployment configuration.
666+
667+
Validation Criteria:
668+
- Fine-tuned models must have a tag 'fine_tune_model_version' as v2 to be supported.
669+
- Fine-tuned models must not have custom metadata 'include_base_model_artifact' as '1'.
670+
671+
Parameters
672+
----------
673+
model_id : str
674+
The OCID of DataScienceModel instance.
675+
model : DataScienceModel
676+
The DataScienceModel instance.
677+
678+
Raises
679+
------
680+
ConfigValidationError
681+
If any of the above conditions are violated.
682+
"""
683+
base_model = DataScienceModel.from_id(model_id) if model_id else model
684+
if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags:
685+
if (
686+
base_model.freeform_tags.get(
687+
Tags.AQUA_FINE_TUNE_MODEL_VERSION, UNKNOWN
688+
).lower()
689+
!= AQUA_FINE_TUNE_MODEL_VERSION
690+
):
691+
logger.error(
692+
"Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.",
693+
base_model.id,
694+
)
695+
raise ConfigValidationError(
696+
f"Invalid fine-tuned model ID '{base_model.id}': only fine tune model {AQUA_FINE_TUNE_MODEL_VERSION} is supported for model deployment. "
697+
f"Run 'ads aqua model convert_fine_tune --model_id {base_model.id}' to convert legacy AQUA fine tuned model to version {AQUA_FINE_TUNE_MODEL_VERSION} for deployment."
698+
)
699+
700+
include_base_model_artifact = base_model.custom_metadata_list.get(
701+
FineTuneCustomMetadata.FINE_TUNE_INCLUDE_BASE_MODEL_ARTIFACT,
702+
ModelCustomMetadataItem(
703+
key=FineTuneCustomMetadata.FINE_TUNE_INCLUDE_BASE_MODEL_ARTIFACT
704+
),
705+
).value
706+
707+
if include_base_model_artifact == INCLUDE_BASE_MODEL:
708+
logger.error(
709+
"Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.",
710+
base_model.id,
711+
)
712+
raise ConfigValidationError(
713+
f"Invalid fine-tuned model ID '{base_model.id}': for fine tuned models like Phi4, the deployment is not supported. "
714+
)
715+
653716
class Config:
654717
extra = "allow"
655718
protected_namespaces = ()

ads/model/datascience_model_group.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -530,7 +530,6 @@ def _build_model_group_details(self) -> dict:
530530
custom_metadata_list=custom_metadata_list,
531531
base_model_id=self.base_model_id,
532532
)
533-
member_model_details.append(MemberModelDetails(model_id=self.base_model_id))
534533
else:
535534
model_group_details = HomogeneousModelGroupDetails(
536535
custom_metadata_list=custom_metadata_list

0 commit comments

Comments
 (0)