66import os
77import traceback
88from dataclasses import fields
9- from typing import Any, Dict, Optional, Union
9+ from datetime import datetime, timedelta
10+ from itertools import chain
11+ from typing import Any, Dict, List, Optional, Union
1012
1113import oci
12- from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
14+ from cachetools import TTLCache, cached
15+ from oci.data_science.models import (
16+ ContainerSummary,
17+ UpdateModelDetails,
18+ UpdateModelProvenanceDetails,
19+ )
1320
1421from ads import set_auth
1522from ads.aqua import logger
2229 is_valid_ocid,
2330 load_config,
2431)
32+ from ads.aqua.config.container_config import (
33+ AquaContainerConfig,
34+ AquaContainerConfigItem,
35+ )
36+ from ads.aqua.constants import SERVICE_MANAGED_CONTAINER_URI_SCHEME
2537from ads.common import oci_client as oc
2638from ads.common.auth import default_signer
2739from ads.common.utils import UNKNOWN, extract_region, is_path_exists
@@ -238,7 +250,9 @@ def create_model_catalog(
238250 .with_custom_metadata_list(model_custom_metadata)
239251 .with_defined_metadata_list(model_taxonomy_metadata)
240252 .with_provenance_metadata(ModelProvenanceMetadata(training_id=UNKNOWN))
241- .with_defined_tags(**(defined_tags or {})) # Create defined tags when a model is created.
253+ .with_defined_tags(
254+ **(defined_tags or {})
255+ ) # Create defined tags when a model is created.
242256 .create(
243257 **kwargs,
244258 )
@@ -269,6 +283,44 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
269283 logger.info(f"Artifact not found in model {model_id}.")
270284 return False
271285
286+ def get_config_from_metadata(
287+ self, model_id: str, metadata_key: str
288+ ) -> ModelConfigResult:
289+ """Gets the config for the given Aqua model from model catalog metadata content.
290+
291+ Parameters
292+ ----------
293+ model_id: str
294+ The OCID of the Aqua model.
295+ metadata_key: str
296+ The metadata key name where artifact content is stored
297+ Returns
298+ -------
299+ ModelConfigResult
300+ A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
301+ """
302+ config = {}
303+ oci_model = self.ds_client.get_model(model_id).data
304+ try:
305+ config = self.ds_client.get_model_defined_metadatum_artifact_content(
306+ model_id, metadata_key
307+ ).data.content.decode("utf-8")
308+ return ModelConfigResult(config=json.loads(config), model_details=oci_model)
309+ except UnicodeDecodeError as ex:
310+ logger.error(
311+ f"Failed to decode content for '{metadata_key}' in defined metadata for model '{model_id}' : {ex}"
312+ )
313+ except json.JSONDecodeError as ex:
314+ logger.error(
315+ f"Invalid JSON format for '{metadata_key}' in defined metadata for model '{model_id}' : {ex}"
316+ )
317+ except Exception as ex:
318+ logger.error(
319+ f"Failed to retrieve defined metadata key '{metadata_key}' for model '{model_id}': {ex}"
320+ )
321+ return ModelConfigResult(config=config, model_details=oci_model)
322+
323+ @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
272324 def get_config(
273325 self,
274326 model_id: str,
@@ -307,22 +359,7 @@ def get_config(
307359 raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
308360
309361 config: Dict[str, Any] = {}
310-
311- # if the current model has a service model tag, then
312- if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
313- base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
314- logger.info(
315- f"Base model found for the model: {oci_model.id}. "
316- f"Loading {config_file_name} for base model {base_model_ocid}."
317- )
318- if config_folder == ConfigFolder.ARTIFACT:
319- artifact_path = get_artifact_path(oci_model.custom_metadata_list)
320- else:
321- base_model = self.ds_client.get_model(base_model_ocid).data
322- artifact_path = get_artifact_path(base_model.custom_metadata_list)
323- else:
324- logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
325- artifact_path = get_artifact_path(oci_model.custom_metadata_list)
362+ artifact_path = get_artifact_path(oci_model.custom_metadata_list)
326363 if not artifact_path:
327364 logger.debug(
328365 f"Failed to get artifact path from custom metadata for the model: {model_id}"
@@ -337,6 +374,9 @@ def get_config(
337374 config_file_path = os.path.join(config_path, config_file_name)
338375 if is_path_exists(config_file_path):
339376 try:
377+ logger.info(
378+ f"Loading config: `{config_file_name}` from `{config_path}`"
379+ )
340380 config = load_config(
341381 config_path,
342382 config_file_name=config_file_name,
@@ -355,6 +395,85 @@ def get_config(
355395
356396 return ModelConfigResult(config=config, model_details=oci_model)
357397
398+ def get_container_image(self, container_type: str = None) -> str:
399+ """
400+ Gets the latest smc container complete image name from the given container type.
401+
402+ Parameters
403+ ----------
404+ container_type: str
405+ type of container, can be either odsc-vllm-serving, odsc-llm-fine-tuning, odsc-llm-evaluate
406+
407+ Returns
408+ -------
409+ str:
410+ A complete container name along with version. ex: dsmc://odsc-vllm-serving:0.7.4.1
411+ """
412+
413+ containers = self.list_service_containers()
414+ container = next(
415+ (c for c in containers if c.is_latest and c.family_name == container_type),
416+ None,
417+ )
418+ if not container:
419+ raise AquaValueError(f"Invalid container type : {container_type}")
420+ container_image = (
421+ SERVICE_MANAGED_CONTAINER_URI_SCHEME
422+ + container.container_name
423+ + ":"
424+ + container.tag
425+ )
426+ return container_image
427+
428+ @cached(cache=TTLCache(maxsize=20, ttl=timedelta(minutes=30), timer=datetime.now))
429+ def list_service_containers(self) -> List[ContainerSummary]:
430+ """
431+ List containers from containers.conf in OCI Datascience control plane
432+ """
433+ containers = self.ds_client.list_containers().data
434+ return containers
435+
436+ def get_container_config(self) -> AquaContainerConfig:
437+ """
438+ Fetches latest containers from containers.conf in OCI Datascience control plane
439+
440+ Returns
441+ -------
442+ AquaContainerConfig
443+ An Object that contains latest container info for the given container family
444+
445+ """
446+ return AquaContainerConfig.from_service_config(
447+ service_containers=self.list_service_containers()
448+ )
449+
450+ def get_container_config_item(
451+ self, container_family: str
452+ ) -> AquaContainerConfigItem:
453+ """
454+ Fetches latest container for given container_family_name from containers.conf in OCI Datascience control plane
455+
456+ Returns
457+ -------
458+ AquaContainerConfigItem
459+ An Object that contains latest container info for the given container family
460+
461+ """
462+
463+ aqua_container_config = self.get_container_config()
464+ inference_config = aqua_container_config.inference.values()
465+ ft_config = aqua_container_config.finetune.values()
466+ eval_config = aqua_container_config.evaluate.values()
467+ container = next(
468+ (
469+ container
470+ for container in chain(inference_config, ft_config, eval_config)
471+ if container.family.lower() == container_family.lower()
472+ ),
473+ None,
474+ )
475+ return container
476+
358477 @property
359478 def telemetry(self):
360479 if not self._telemetry:
0 commit comments