Skip to content

Commit c51aa4a

Browse files
Validate FT and MD parameter overrides (#845)
2 parents ea3b5e4 + 8b5c171 commit c51aa4a

File tree

11 files changed

+532
-57
lines changed

11 files changed

+532
-57
lines changed

ads/aqua/common/utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,3 +755,72 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
755755

756756
def is_service_managed_container(container):
757757
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
758+
759+
760+
def get_params_list(params: str) -> List[str]:
761+
"""Parses the string parameter and returns a list of params.
762+
763+
Parameters
764+
----------
765+
params
766+
string parameters by separated by -- delimiter
767+
768+
Returns
769+
-------
770+
list of params
771+
772+
"""
773+
if not params:
774+
return []
775+
return ["--" + param.strip() for param in params.split("--")[1:]]
776+
777+
778+
def get_params_dict(params: Union[str, List[str]]) -> dict:
779+
"""Accepts a string or list of string of double-dash parameters and returns a dict with the parameter keys and values.
780+
781+
Parameters
782+
----------
783+
params:
784+
List of parameters or parameter string separated by space.
785+
786+
Returns
787+
-------
788+
dict containing parameter keys and values
789+
790+
"""
791+
params_list = get_params_list(params) if isinstance(params, str) else params
792+
return {
793+
split_result[0]: split_result[1] if len(split_result) > 1 else UNKNOWN
794+
for split_result in (x.split() for x in params_list)
795+
}
796+
797+
798+
def get_combined_params(params1: str = None, params2: str = None) -> str:
799+
"""
800+
Combines string of double-dash parameters, and overrides the values from the second string in the first.
801+
Parameters
802+
----------
803+
params1:
804+
Parameter string with values
805+
params2:
806+
Parameter string with values that need to be overridden.
807+
808+
Returns
809+
-------
810+
A combined list with overridden values from params2.
811+
"""
812+
if not params1:
813+
return params2
814+
if not params2:
815+
return params1
816+
817+
# overwrite values from params2 into params1
818+
combined_params = [
819+
f"{key} {value}" if value else key
820+
for key, value in {
821+
**get_params_dict(params1),
822+
**get_params_dict(params2),
823+
}.items()
824+
]
825+
826+
return " ".join(combined_params)

ads/aqua/extension/deployment_handler.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def post(self, *args, **kwargs):
9494
description = input_data.get("description")
9595
instance_count = input_data.get("instance_count")
9696
bandwidth_mbps = input_data.get("bandwidth_mbps")
97+
web_concurrency = input_data.get("web_concurrency")
98+
server_port = input_data.get("server_port")
99+
health_check_port = input_data.get("health_check_port")
100+
env_var = input_data.get("env_var")
97101

98102
self.finish(
99103
AquaDeploymentApp().create(
@@ -108,6 +112,10 @@ def post(self, *args, **kwargs):
108112
access_log_id=access_log_id,
109113
predict_log_id=predict_log_id,
110114
bandwidth_mbps=bandwidth_mbps,
115+
web_concurrency=web_concurrency,
116+
server_port=server_port,
117+
health_check_port=health_check_port,
118+
env_var=env_var,
111119
)
112120
)
113121

@@ -194,29 +202,59 @@ def post(self, *args, **kwargs):
194202

195203

196204
class AquaDeploymentParamsHandler(AquaAPIhandler):
197-
"""Handler for Aqua finetuning params REST APIs.
205+
"""Handler for Aqua deployment params REST APIs.
198206
199207
Methods
200208
-------
201209
get(self, model_id)
202210
Retrieves a list of model deployment parameters.
211+
post(self, *args, **kwargs)
212+
Validates parameters for the given model id.
203213
"""
204214

205215
@handle_exceptions
206216
def get(self, model_id):
207217
"""Handle GET request."""
208-
model_id = model_id.split("/")[0]
209218
instance_shape = self.get_argument("instance_shape")
210219
return self.finish(
211220
AquaDeploymentApp().get_deployment_default_params(
212221
model_id=model_id, instance_shape=instance_shape
213222
)
214223
)
215224

225+
@handle_exceptions
226+
def post(self, *args, **kwargs):
227+
"""Handles post request for the deployment param handler API.
228+
229+
Raises
230+
------
231+
HTTPError
232+
Raises HTTPError if inputs are missing or are invalid.
233+
"""
234+
try:
235+
input_data = self.get_json_body()
236+
except Exception:
237+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
238+
239+
if not input_data:
240+
raise HTTPError(400, Errors.NO_INPUT_DATA)
241+
242+
model_id = input_data.get("model_id")
243+
if not model_id:
244+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))
245+
246+
params = input_data.get("params")
247+
return self.finish(
248+
AquaDeploymentApp().validate_deployment_params(
249+
model_id=model_id,
250+
params=params,
251+
)
252+
)
253+
216254

217255
__handlers__ = [
218-
("deployments/?([^/]*)", AquaDeploymentHandler),
256+
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
219257
("deployments/config/?([^/]*)", AquaDeploymentHandler),
220-
("deployments/?([^/]*/params)", AquaDeploymentParamsHandler),
258+
("deployments/?([^/]*)", AquaDeploymentHandler),
221259
("inference", AquaDeploymentInferenceHandler),
222260
]

ads/aqua/extension/finetune_handler.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,33 @@ def get(self, model_id):
7070
AquaFineTuningApp().get_finetuning_default_params(model_id=model_id)
7171
)
7272

73+
@handle_exceptions
74+
def post(self, *args, **kwargs):
75+
"""Handles post request for the finetuning param handler API.
76+
77+
Raises
78+
------
79+
HTTPError
80+
Raises HTTPError if inputs are missing or are invalid.
81+
"""
82+
try:
83+
input_data = self.get_json_body()
84+
except Exception:
85+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
86+
87+
if not input_data:
88+
raise HTTPError(400, Errors.NO_INPUT_DATA)
89+
90+
params = input_data.get("params", None)
91+
return self.finish(
92+
AquaFineTuningApp().validate_finetuning_params(
93+
params=params,
94+
)
95+
)
96+
7397

7498
__handlers__ = [
99+
("finetuning/?([^/]*)/params", AquaFineTuneParamsHandler),
75100
("finetuning/?([^/]*)", AquaFineTuneHandler),
76101
("finetuning/config/?([^/]*)", AquaFineTuneHandler),
77-
("finetuning/([^/]*)/params", AquaFineTuneParamsHandler),
78102
]

ads/aqua/finetuning/finetuning.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
from ads.aqua.app import AquaApp
1919
from ads.aqua.common.enums import Resource, Tags
2020
from ads.aqua.common.errors import AquaFileExistsError, AquaValueError
21-
from ads.aqua.common.utils import get_container_image, upload_local_to_os
21+
from ads.aqua.common.utils import (
22+
get_container_image,
23+
upload_local_to_os,
24+
get_params_dict,
25+
)
2226
from ads.aqua.constants import (
2327
DEFAULT_FT_BATCH_SIZE,
2428
DEFAULT_FT_BLOCK_STORAGE_SIZE,
@@ -583,3 +587,32 @@ def get_finetuning_default_params(self, model_id: str) -> List[str]:
583587
default_params.append(f"--{name} {str(value).lower()}")
584588

585589
return default_params
590+
591+
def validate_finetuning_params(self, params: List[str] = None) -> Dict:
592+
"""Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
593+
validated, only param keys are validated.
594+
595+
Parameters
596+
----------
597+
params : List[str], optional
598+
Params passed by the user.
599+
600+
Returns
601+
-------
602+
Return a list of restricted params.
603+
"""
604+
restricted_params = []
605+
if params:
606+
dataclass_fields = {field.name for field in fields(AquaFineTuningParams)}
607+
params_dict = get_params_dict(params)
608+
for key, items in params_dict.items():
609+
key = key.lstrip("--")
610+
if key not in dataclass_fields:
611+
restricted_params.append(key)
612+
613+
if restricted_params:
614+
raise AquaValueError(
615+
f"Parameters {restricted_params} are set by Aqua "
616+
f"and cannot be overridden or are invalid."
617+
)
618+
return dict(valid=True)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
"""
7+
aqua.modeldeployment.constants
8+
~~~~~~~~~~~~~~
9+
10+
This module contains constants used in Aqua Model Deployment.
11+
"""
12+
13+
VLLMInferenceRestrictedParams = {"tensor-parallel-size"}

0 commit comments

Comments
 (0)