From 5c4980c9f7919e418db746201f14e0d4e8094aa1 Mon Sep 17 00:00:00 2001 From: kartikmandar Date: Fri, 28 Nov 2025 03:04:08 +0530 Subject: [PATCH] Fix estimation procedure extraction for clustering tasks (#1522) --- openml/tasks/functions.py | 29 +++++++++++------------- tests/test_tasks/test_clustering_task.py | 24 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/openml/tasks/functions.py b/openml/tasks/functions.py index d2bf5e946..c9e42f7be 100644 --- a/openml/tasks/functions.py +++ b/openml/tasks/functions.py @@ -492,32 +492,29 @@ def _create_task_from_xml(xml: str) -> OpenMLTask: "data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"], "evaluation_measure": evaluation_measures, } - # TODO: add OpenMLClusteringTask? if task_type in ( TaskType.SUPERVISED_CLASSIFICATION, TaskType.SUPERVISED_REGRESSION, TaskType.LEARNING_CURVE, + TaskType.CLUSTERING, ): - # Convert some more parameters - for parameter in inputs["estimation_procedure"]["oml:estimation_procedure"][ - "oml:parameter" - ]: + est_proc = inputs["estimation_procedure"]["oml:estimation_procedure"] + parameters = est_proc.get("oml:parameter", []) + if isinstance(parameters, dict): + parameters = [parameters] + for parameter in parameters: name = parameter["@name"] text = parameter.get("#text", "") estimation_parameters[name] = text - common_kwargs["estimation_procedure_type"] = inputs["estimation_procedure"][ - "oml:estimation_procedure" - ]["oml:type"] - common_kwargs["estimation_procedure_id"] = int( - inputs["estimation_procedure"]["oml:estimation_procedure"]["oml:id"] - ) - + common_kwargs["estimation_procedure_type"] = est_proc.get("oml:type") + est_proc_id = est_proc.get("oml:id") + common_kwargs["estimation_procedure_id"] = int(est_proc_id) if est_proc_id else None common_kwargs["estimation_parameters"] = estimation_parameters - common_kwargs["target_name"] = inputs["source_data"]["oml:data_set"]["oml:target_feature"] - common_kwargs["data_splits_url"] = inputs["estimation_procedure"][ - "oml:estimation_procedure" - ]["oml:data_splits_url"] + common_kwargs["target_name"] = ( + inputs["source_data"]["oml:data_set"].get("oml:target_feature") or None + ) + common_kwargs["data_splits_url"] = est_proc.get("oml:data_splits_url") cls = { TaskType.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask, diff --git a/tests/test_tasks/test_clustering_task.py b/tests/test_tasks/test_clustering_task.py index dcc024388..48dc35ec8 100644 --- a/tests/test_tasks/test_clustering_task.py +++ b/tests/test_tasks/test_clustering_task.py @@ -36,6 +36,30 @@ def test_download_task(self): assert task.task_type_id == TaskType.CLUSTERING assert task.dataset_id == 36 + @pytest.mark.production() + def test_estimation_procedure_extraction(self): + # task 126033 has complete estimation procedure data + self.use_production_server() + task = openml.tasks.get_task(126033, download_data=False) + + assert task.task_type_id == TaskType.CLUSTERING + assert task.estimation_procedure_id == 17 + + est_proc = task.estimation_procedure + assert est_proc["type"] == "testontrainingdata" + assert est_proc["parameters"] is not None + assert "number_repeats" in est_proc["parameters"] + assert est_proc["data_splits_url"] is not None + + @pytest.mark.production() + def test_estimation_procedure_empty_fields(self): + # task 146714 has empty estimation procedure fields in XML + self.use_production_server() + task = openml.tasks.get_task(self.task_id, download_data=False) + + assert task.task_type_id == TaskType.CLUSTERING + assert task.estimation_procedure_id == 17 + def test_upload_task(self): compatible_datasets = self._get_compatible_rand_dataset() for i in range(100):