Skip to content

Commit b75f792

Browse files
authored
Aqua refactor: Move all enums under enums.py for centralized management (#848)
1 parent 8c74952 commit b75f792

File tree

20 files changed

+143
-145
lines changed

20 files changed

+143
-145
lines changed

ads/aqua/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from ads import set_auth
1414
from ads.aqua import logger
15+
from ads.aqua.common.enums import Tags
1516
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1617
from ads.aqua.common.utils import (
1718
UNKNOWN,
@@ -20,7 +21,6 @@
2021
is_valid_ocid,
2122
load_config,
2223
)
23-
from ads.aqua.data import Tags
2424
from ads.common import oci_client as oc
2525
from ads.common.auth import default_signer
2626
from ads.common.utils import extract_region

ads/aqua/common/enums.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,14 @@
1111
from ads.common.extended_enum import ExtendedEnumMeta
1212

1313

14+
class DataScienceResource(str, metaclass=ExtendedEnumMeta):
15+
MODEL_DEPLOYMENT = "datasciencemodeldeployment"
16+
MODEL = "datasciencemodel"
17+
18+
1419
class Resource(str, metaclass=ExtendedEnumMeta):
1520
JOB = "jobs"
21+
JOBRUN = "jobruns"
1622
MODEL = "models"
1723
MODEL_DEPLOYMENT = "modeldeployments"
1824
MODEL_VERSION_SET = "model-version-sets"
@@ -21,3 +27,28 @@ class Resource(str, metaclass=ExtendedEnumMeta):
2127
class DataScienceResource(str, metaclass=ExtendedEnumMeta):
2228
MODEL_DEPLOYMENT = "datasciencemodeldeployment"
2329
MODEL = "datasciencemodel"
30+
31+
32+
class Tags(str, metaclass=ExtendedEnumMeta):
33+
TASK = "task"
34+
LICENSE = "license"
35+
ORGANIZATION = "organization"
36+
AQUA_TAG = "OCI_AQUA"
37+
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
38+
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
39+
AQUA_MODEL_NAME_TAG = "aqua_model_name"
40+
AQUA_EVALUATION = "aqua_evaluation"
41+
AQUA_FINE_TUNING = "aqua_finetuning"
42+
READY_TO_FINE_TUNE = "ready_to_fine_tune"
43+
READY_TO_IMPORT = "ready_to_import"
44+
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
45+
46+
47+
class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
48+
METADATA = "metadata"
49+
CREATED_BY = "createdBy"
50+
DESCRIPTION = "description"
51+
MODEL_VERSION_SET_ID = "modelVersionSetId"
52+
MODEL_VERSION_SET_NAME = "modelVersionSetName"
53+
PROJECT_ID = "projectId"
54+
VERSION_LABEL = "versionLabel"

ads/aqua/common/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,13 @@
2121
import oci
2222
from oci.data_science.models import JobRun, Model
2323

24+
from ads.aqua.common.enums import RqsAdditionalDetails
2425
from ads.aqua.common.errors import (
2526
AquaFileNotFoundError,
2627
AquaRuntimeError,
2728
AquaValueError,
2829
)
29-
from ads.aqua.constants import (
30-
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
31-
RqsAdditionalDetails,
32-
)
30+
from ads.aqua.constants import SERVICE_MANAGED_CONTAINER_URI_SCHEME
3331
from ads.aqua.data import AquaResourceIdentifier
3432
from ads.common.auth import AuthState, default_signer
3533
from ads.common.extended_enum import ExtendedEnumMeta

ads/aqua/constants.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Copyright (c) 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
"""This module defines constants used in ads.aqua module."""
6-
from ads.common.extended_enum import ExtendedEnumMeta
76

87
UNKNOWN = ""
98
UNKNOWN_VALUE = ""
@@ -40,34 +39,3 @@
4039
VALIDATION_METRICS = "validation_metrics"
4140

4241
SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://"
43-
44-
45-
class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
46-
METADATA = "metadata"
47-
CREATED_BY = "createdBy"
48-
DESCRIPTION = "description"
49-
MODEL_VERSION_SET_ID = "modelVersionSetId"
50-
MODEL_VERSION_SET_NAME = "modelVersionSetName"
51-
PROJECT_ID = "projectId"
52-
VERSION_LABEL = "versionLabel"
53-
54-
55-
class FineTuningDefinedMetadata(str, metaclass=ExtendedEnumMeta):
56-
"""Represents the defined metadata keys used in Fine Tuning."""
57-
58-
VAL_SET_SIZE = "val_set_size"
59-
TRAINING_DATA = "training_data"
60-
61-
62-
class FineTuningCustomMetadata(str, metaclass=ExtendedEnumMeta):
63-
"""Represents the custom metadata keys used in Fine Tuning."""
64-
65-
FT_SOURCE = "fine_tune_source"
66-
FT_SOURCE_NAME = "fine_tune_source_name"
67-
FT_OUTPUT_PATH = "fine_tune_output_path"
68-
FT_JOB_ID = "fine_tune_job_id"
69-
FT_JOB_RUN_ID = "fine_tune_jobrun_id"
70-
TRAINING_METRICS_FINAL = "train_metrics_final"
71-
VALIDATION_METRICS_FINAL = "val_metrics_final"
72-
TRAINING_METRICS_EPOCH = "train_metrics_epoch"
73-
VALIDATION_METRICS_EPOCH = "val_metrics_epoch"

ads/aqua/data.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from dataclasses import dataclass
77

8-
from ads.common.extended_enum import ExtendedEnumMeta
98
from ads.common.serializer import DataClassSerializable
109

1110

@@ -14,46 +13,3 @@ class AquaResourceIdentifier(DataClassSerializable):
1413
id: str = ""
1514
name: str = ""
1615
url: str = ""
17-
18-
19-
class Resource(str, metaclass=ExtendedEnumMeta):
20-
JOB = "jobs"
21-
JOBRUN = "jobruns"
22-
MODEL = "models"
23-
MODEL_DEPLOYMENT = "modeldeployments"
24-
MODEL_VERSION_SET = "model-version-sets"
25-
26-
27-
class DataScienceResource(str, metaclass=ExtendedEnumMeta):
28-
MODEL_DEPLOYMENT = "datasciencemodeldeployment"
29-
MODEL = "datasciencemodel"
30-
31-
32-
class Tags(str, metaclass=ExtendedEnumMeta):
33-
TASK = "task"
34-
LICENSE = "license"
35-
ORGANIZATION = "organization"
36-
AQUA_TAG = "OCI_AQUA"
37-
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
38-
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
39-
AQUA_MODEL_NAME_TAG = "aqua_model_name"
40-
AQUA_EVALUATION = "aqua_evaluation"
41-
AQUA_FINE_TUNING = "aqua_finetuning"
42-
READY_TO_FINE_TUNE = "ready_to_fine_tune"
43-
READY_TO_IMPORT = "ready_to_import"
44-
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
45-
46-
47-
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
48-
CONTAINER_TYPE_VLLM = "vllm"
49-
CONTAINER_TYPE_TGI = "tgi"
50-
51-
52-
class InferenceContainerTypeKey(str, metaclass=ExtendedEnumMeta):
53-
AQUA_VLLM_CONTAINER_KEY = "odsc-vllm-serving"
54-
AQUA_TGI_CONTAINER_KEY = "odsc-tgi-serving"
55-
56-
57-
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
58-
PARAM_TYPE_VLLM = "VLLM_PARAMS"
59-
PARAM_TYPE_TGI = "TGI_PARAMS"

ads/aqua/evaluation/evaluation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
from ads.aqua import logger
2727
from ads.aqua.app import AquaApp
2828
from ads.aqua.common import utils
29-
from ads.aqua.common.enums import *
29+
from ads.aqua.common.enums import (
30+
DataScienceResource,
31+
Resource,
32+
RqsAdditionalDetails,
33+
Tags,
34+
)
3035
from ads.aqua.common.errors import (
3136
AquaFileExistsError,
3237
AquaFileNotFoundError,
@@ -44,7 +49,6 @@
4449
is_valid_ocid,
4550
upload_local_to_os,
4651
)
47-
from ads.aqua.data import Tags
4852
from ads.aqua.evaluation.constants import *
4953
from ads.aqua.evaluation.entities import *
5054
from ads.aqua.evaluation.errors import *

ads/aqua/extension/model_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def post(self, *args, **kwargs):
179179
)
180180

181181
# Check pipeline_tag, it should be `text-generation`
182-
if hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION.value:
182+
if hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION:
183183
raise AquaRuntimeError(
184184
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
185185
f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "

ads/aqua/extension/ui_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tornado.web import HTTPError
1010

1111
from ads.aqua.common.decorator import handle_exceptions
12-
from ads.aqua.data import Tags
12+
from ads.aqua.common.enums import Tags
1313
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
1414
from ads.aqua.extension.utils import validate_function_parameters
1515
from ads.aqua.model.entities import ImportModelDetails

ads/aqua/finetuning/finetuning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1818
from ads.aqua.app import AquaApp
19+
from ads.aqua.common.enums import Resource, Tags
1920
from ads.aqua.common.errors import AquaFileExistsError, AquaValueError
2021
from ads.aqua.common.utils import (
2122
DEFAULT_FT_BATCH_SIZE,
@@ -28,7 +29,7 @@
2829
get_container_image,
2930
upload_local_to_os,
3031
)
31-
from ads.aqua.data import AquaResourceIdentifier, Resource, Tags
32+
from ads.aqua.data import AquaResourceIdentifier
3233
from ads.aqua.finetuning.constants import *
3334
from ads.aqua.finetuning.entities import *
3435
from ads.common.auth import default_signer

ads/aqua/job.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import logging
99
from dataclasses import dataclass, field
10-
from ads.common.serializer import DataClassSerializable
10+
1111
from ads.aqua.data import AquaResourceIdentifier
12+
from ads.common.serializer import DataClassSerializable
1213

1314
logger = logging.getLogger(__name__)
1415

0 commit comments

Comments
 (0)