Skip to content

Commit ff40286

Browse files
authored
Add option to set logging level (#795)
1 parent 2d74979 commit ff40286

File tree

9 files changed

+174
-37
lines changed

9 files changed

+174
-37
lines changed

ads/aqua/__init__.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,50 @@
55

66

77
import logging
8-
import sys
98
import os
9+
import sys
10+
11+
from ads import set_auth
1012
from ads.aqua.utils import fetch_service_compartment
1113
from ads.config import NB_SESSION_OCID, OCI_RESOURCE_PRINCIPAL_VERSION
12-
from ads import set_auth
1314

14-
logger = logging.getLogger(__name__)
15-
handler = logging.StreamHandler(sys.stdout)
16-
logger.setLevel(logging.INFO)
15+
ENV_VAR_LOG_LEVEL = "ADS_AQUA_LOG_LEVEL"
16+
17+
18+
def get_logger_level():
19+
"""Retrieves logging level from environment variable `LOG_LEVEL`."""
20+
level = os.environ.get(ENV_VAR_LOG_LEVEL, "INFO").upper()
21+
return level
22+
23+
24+
def configure_aqua_logger():
25+
"""Configures the AQUA logger."""
26+
log_level = get_logger_level()
27+
logger = logging.getLogger(__name__)
28+
logger.setLevel(log_level)
29+
30+
handler = logging.StreamHandler(sys.stdout)
31+
formatter = logging.Formatter(
32+
"%(asctime)s - %(name)s.%(module)s - %(levelname)s - %(message)s"
33+
)
34+
handler.setFormatter(formatter)
35+
handler.setLevel(log_level)
36+
37+
logger.addHandler(handler)
38+
logger.propagate = False
39+
return logger
40+
41+
42+
logger = configure_aqua_logger()
43+
44+
45+
def set_log_level(log_level: str):
46+
"""Global for setting logging level."""
47+
48+
log_level = log_level.upper()
49+
logger.setLevel(log_level.upper())
50+
logger.handlers[0].setLevel(log_level)
51+
1752

1853
if OCI_RESOURCE_PRINCIPAL_VERSION:
1954
set_auth("resource_principal")

ads/aqua/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
get_artifact_path,
2020
is_valid_ocid,
2121
load_config,
22-
logger,
2322
)
2423
from ads.common import oci_client as oc
2524
from ads.common.auth import default_signer
@@ -164,7 +163,7 @@ def create_model_version_set(
164163
tag = Tags.AQUA_FINE_TUNING.value
165164

166165
if not model_version_set_id:
167-
tag = Tags.AQUA_FINE_TUNING.value # TODO: Fix this
166+
tag = Tags.AQUA_FINE_TUNING.value # TODO: Fix this
168167
try:
169168
model_version_set = ModelVersionSet.from_name(
170169
name=model_version_set_name,

ads/aqua/cli.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,41 @@
33

44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
import os
67

8+
from ads.aqua import ENV_VAR_LOG_LEVEL, set_log_level
79
from ads.aqua.deployment import AquaDeploymentApp
10+
from ads.aqua.evaluation import AquaEvaluationApp
811
from ads.aqua.finetune import AquaFineTuningApp
912
from ads.aqua.model import AquaModelApp
10-
from ads.aqua.evaluation import AquaEvaluationApp
1113

1214

1315
class AquaCommand:
14-
"""Contains the command groups for project Aqua."""
16+
"""Contains the command groups for project Aqua.
17+
18+
Acts as an entry point for managing different components of the Aqua
19+
project including model management, fine-tuning, deployment, and
20+
evaluation.
21+
"""
1522

1623
model = AquaModelApp
1724
fine_tuning = AquaFineTuningApp
1825
deployment = AquaDeploymentApp
1926
evaluation = AquaEvaluationApp
27+
28+
def __init__(
29+
self,
30+
log_level: str = os.environ.get(ENV_VAR_LOG_LEVEL, "ERROR").upper(),
31+
):
32+
"""
33+
Initialize the command line interface settings for the Aqua project.
34+
35+
FLAGS
36+
-----
37+
log_level (str):
38+
Sets the logging level for the application.
39+
Default is retrieved from environment variable `LOG_LEVEL`,
40+
or 'ERROR' if not set. Example values include 'DEBUG', 'INFO',
41+
'WARNING', 'ERROR', and 'CRITICAL'.
42+
"""
43+
set_log_level(log_level)

ads/aqua/evaluation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,7 @@ def list(
10041004
self._process_evaluation_summary(model=model, jobrun=jobrun)
10051005
)
10061006
except Exception as exc:
1007-
logger.error(
1007+
logger.debug(
10081008
f"Processing evaluation: {model.identifier} generated an exception: {exc}"
10091009
)
10101010
evaluations.append(
@@ -1049,7 +1049,7 @@ def _if_eval_artifact_exist(
10491049
return True if response.status == 200 else False
10501050
except oci.exceptions.ServiceError as ex:
10511051
if ex.status == 404:
1052-
logger.info("Evaluation artifact not found.")
1052+
logger.debug(f"Evaluation artifact not found for {model.identifier}.")
10531053
return False
10541054

10551055
@telemetry(entry_point="plugin=evaluation&action=get_status", name="aqua")
@@ -1595,8 +1595,9 @@ def _build_resource_identifier(
15951595
),
15961596
)
15971597
except Exception as e:
1598-
logger.error(
1599-
f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}"
1598+
logger.debug(
1599+
f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`. "
1600+
f"DEBUG INFO: {str(e)}"
16001601
)
16011602
return AquaResourceIdentifier()
16021603

@@ -1642,7 +1643,7 @@ def _fetch_runtime_params(
16421643
)
16431644
if not params.get(EvaluationConfig.PARAMS):
16441645
raise AquaMissingKeyError(
1645-
"model parameters have not been saved in correct format in model taxonomy.",
1646+
"model parameters have not been saved in correct format in model taxonomy. ",
16461647
service_payload={"params": params},
16471648
)
16481649
# TODO: validate the format of parameters.
@@ -1674,7 +1675,7 @@ def _build_job_identifier(
16741675

16751676
except Exception as e:
16761677
logger.debug(
1677-
f"Failed to get job details from job_run_details: {job_run_details}"
1678+
f"Failed to get job details from job_run_details: {job_run_details} "
16781679
f"DEBUG INFO:{str(e)}"
16791680
)
16801681
return AquaResourceIdentifier()

ads/aqua/finetune.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
UpdateModelProvenanceDetails,
1616
)
1717

18-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
18+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1919
from ads.aqua.base import AquaApp
2020
from ads.aqua.data import AquaResourceIdentifier, Resource, Tags
2121
from ads.aqua.exception import AquaFileExistsError, AquaValueError
@@ -29,7 +29,6 @@
2929
UNKNOWN,
3030
UNKNOWN_DICT,
3131
get_container_image,
32-
logger,
3332
upload_local_to_os,
3433
)
3534
from ads.common.auth import default_signer

ads/aqua/model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from cachetools import TTLCache
1515
from oci.data_science.models import JobRun, Model
1616

17-
from ads.aqua import logger, utils
17+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger, utils
1818
from ads.aqua.base import AquaApp
1919
from ads.aqua.constants import (
2020
TRAINING_METRICS_FINAL,
@@ -26,7 +26,6 @@
2626
)
2727
from ads.aqua.data import AquaResourceIdentifier, Tags
2828
from ads.aqua.exception import AquaRuntimeError
29-
3029
from ads.aqua.training.exceptions import exit_code_dict
3130
from ads.aqua.utils import (
3231
LICENSE_TXT,
@@ -50,7 +49,6 @@
5049
PROJECT_OCID,
5150
TENANCY_OCID,
5251
)
53-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
5452
from ads.model import DataScienceModel
5553
from ads.model.model_metadata import MetadataTaxonomyKeys, ModelCustomMetadata
5654
from ads.telemetry import telemetry
@@ -228,7 +226,7 @@ def __post_init__(
228226
).value
229227
except Exception as e:
230228
logger.debug(
231-
f"Failed to extract model hyperparameters from {model.id}:" f"{str(e)}"
229+
f"Failed to extract model hyperparameters from {model.id}: " f"{str(e)}"
232230
)
233231
model_hyperparameters = {}
234232

ads/aqua/utils.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import os
1111
import random
1212
import re
13-
import sys
1413
from enum import Enum
1514
from functools import wraps
1615
from pathlib import Path
@@ -22,22 +21,16 @@
2221
from oci.data_science.models import JobRun, Model
2322

2423
from ads.aqua.constants import RqsAdditionalDetails
25-
from ads.aqua.data import AquaResourceIdentifier, Tags
24+
from ads.aqua.data import AquaResourceIdentifier
2625
from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError
2726
from ads.common.auth import default_signer
2827
from ads.common.object_storage_details import ObjectStorageDetails
2928
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
3029
from ads.common.utils import get_console_link, upload_to_os
31-
from ads.config import (
32-
AQUA_SERVICE_MODELS_BUCKET,
33-
CONDA_BUCKET_NS,
34-
TENANCY_OCID,
35-
)
30+
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
3631
from ads.model import DataScienceModel, ModelVersionSet
3732

38-
# TODO: allow the user to setup the logging level?
39-
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
40-
logger = logging.getLogger("ODSC_AQUA")
33+
logger = logging.getLogger("ads.aqua")
4134

4235
UNKNOWN = ""
4336
UNKNOWN_DICT = {}
@@ -145,10 +138,6 @@ def get_status(evaluation_status: str, job_run_status: str = None):
145138
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
146139

147140

148-
def get_logger():
149-
return logger
150-
151-
152141
def random_color_generator(word: str):
153142
seed = sum([ord(c) for c in word]) % 13
154143
random.seed(seed)
@@ -235,7 +224,7 @@ def read_file(file_path: str, **kwargs) -> str:
235224
with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f:
236225
return f.read()
237226
except Exception as e:
238-
logger.error(f"Failed to read file {file_path}. {e}")
227+
logger.debug(f"Failed to read file {file_path}. {e}")
239228
return UNKNOWN
240229

241230

@@ -485,7 +474,7 @@ def _build_resource_identifier(
485474
),
486475
)
487476
except Exception as e:
488-
logger.error(
477+
logger.debug(
489478
f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}"
490479
)
491480
return AquaResourceIdentifier()
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import logging
8+
import subprocess
9+
from unittest import TestCase
10+
from unittest.mock import patch
11+
12+
from parameterized import parameterized
13+
14+
from ads.aqua.cli import AquaCommand
15+
16+
17+
class TestAquaCLI(TestCase):
18+
"""Tests the AQUA CLI."""
19+
20+
DEFAUL_AQUA_CLI_LOGGING_LEVEL = "ERROR"
21+
logger = logging.getLogger(__name__)
22+
logging.basicConfig(
23+
format="%(asctime)s %(module)s %(levelname)s: %(message)s",
24+
datefmt="%m/%d/%Y %I:%M:%S %p",
25+
level=logging.INFO,
26+
)
27+
28+
def test_entrypoint(self):
29+
"""Tests CLI entrypoint."""
30+
result = subprocess.run(["ads", "aqua", "--help"], capture_output=True)
31+
self.logger.info(f"{self._testMethodName}\n" + result.stderr.decode("utf-8"))
32+
assert result.returncode == 0
33+
34+
@parameterized.expand(
35+
[
36+
("default", None, DEFAUL_AQUA_CLI_LOGGING_LEVEL),
37+
("set logging level", "info", "info"),
38+
]
39+
)
40+
@patch("ads.aqua.cli.set_log_level")
41+
def test_aquacommand(self, name, arg, expected, mock_setting_log):
42+
"""Tests aqua command initailzation."""
43+
if arg:
44+
AquaCommand(arg)
45+
else:
46+
AquaCommand()
47+
mock_setting_log.assert_called_with(expected)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
import unittest
7+
from unittest.mock import MagicMock, patch
8+
9+
from ads.aqua import configure_aqua_logger, get_logger_level, set_log_level
10+
11+
12+
class TestAquaLogging(unittest.TestCase):
13+
DEFAULT_AQUA_LOG_LEVEL = "INFO"
14+
15+
@patch.dict("os.environ", {})
16+
def test_get_logger_level_default(self):
17+
"""Test default log level when environment variable is not set."""
18+
self.assertEqual(get_logger_level(), self.DEFAULT_AQUA_LOG_LEVEL)
19+
20+
@patch.dict("os.environ", {"ADS_AQUA_LOG_LEVEL": "DEBUG"})
21+
def test_get_logger_level_from_env(self):
22+
"""Test log level is correctly read from environment variable."""
23+
self.assertEqual(get_logger_level(), "DEBUG")
24+
25+
@patch("logging.getLogger")
26+
@patch("logging.StreamHandler")
27+
def test_configure_aqua_logger(self, mock_handler, mock_get_logger):
28+
"""Test that logger is correctly configured."""
29+
mock_logger = MagicMock()
30+
mock_get_logger.return_value = mock_logger
31+
32+
logger = configure_aqua_logger()
33+
34+
mock_get_logger.assert_called_once_with("ads.aqua")
35+
mock_logger.setLevel.assert_called_with(self.DEFAULT_AQUA_LOG_LEVEL)
36+
37+
@patch("ads.aqua.logger", create=True)
38+
def test_set_log_level(self, mock_logger):
39+
"""Test that the log level of the logger is set correctly."""
40+
mock_handler = MagicMock()
41+
mock_logger.handlers = [mock_handler]
42+
43+
set_log_level("warning")
44+
45+
mock_logger.setLevel.assert_called_with("WARNING")

0 commit comments

Comments
 (0)