diff --git a/openml/cli.py b/openml/cli.py
index d0a46e498..e12590911 100644
--- a/openml/cli.py
+++ b/openml/cli.py
@@ -9,7 +9,7 @@
 from typing import Callable
 from urllib.parse import urlparse
 
-from openml import config
+from openml import config, tasks
 
 
 def is_hex(string_: str) -> bool:
@@ -299,6 +299,109 @@ def configure_field(  # noqa: PLR0913
         verbose_set(field, value)
 
 
+def tasks_list(args: argparse.Namespace) -> None:
+    """List tasks with optional filtering."""
+    # Build filter parameters
+    kwargs = {}
+    if args.tag:
+        kwargs["tag"] = args.tag
+    if args.task_type:
+        kwargs["task_type"] = args.task_type
+    if args.status:
+        kwargs["status"] = args.status
+    if args.data_name:
+        kwargs["data_name"] = args.data_name
+
+    # Fetch tasks
+    try:
+        tasks_df = tasks.list_tasks(
+            offset=args.offset,
+            size=args.size,
+            **kwargs,
+        )
+        # Simple table output (reuse _format_output if it exists, otherwise simple print)
+        if tasks_df.empty:
+            print("No results found.")
+        elif args.format == "json":
+            print(tasks_df.to_json(orient="records", indent=2))
+        else:
+            print(tasks_df.to_string(index=False))
+    except Exception as e:  # noqa: BLE001
+        print(f"Error listing tasks: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def tasks_info(args: argparse.Namespace) -> None:
+    """Display detailed information about a specific task."""
+    try:
+        task = tasks.get_task(int(args.task_id))
+
+        # Print basic information
+        print(f"Task ID: {task.task_id}")
+        print(f"Task Type: {task.task_type}")
+        print(f"Dataset ID: {task.dataset_id}")
+        print(f"Estimation Procedure: {task.estimation_procedure.get('type', 'N/A')}")
+        if task.evaluation_measure:
+            print(f"Evaluation Measure: {task.evaluation_measure}")
+
+        # Print target feature if available
+        if hasattr(task, "target_name") and task.target_name:
+            print(f"Target Feature: {task.target_name}")
+
+        # Print class labels if available
+        if hasattr(task, "class_labels") and task.class_labels:
+            print(f"Number of Classes: {len(task.class_labels)}")
+            print(f"Class Labels: {', '.join(map(str, task.class_labels[:10]))}")
+            if len(task.class_labels) > 10:
+                print(f"  ... and {len(task.class_labels) - 10} more")
+
+    except Exception as e:  # noqa: BLE001
+        print(f"Error fetching task info: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def tasks_search(args: argparse.Namespace) -> None:
+    """Search tasks by associated dataset name (case-insensitive)."""
+    try:
+        # Search by dataset name using the API
+        tasks_df = tasks.list_tasks(data_name=args.query, size=args.size or 100)
+
+        # If no exact match, do case-insensitive client-side filtering
+        if tasks_df.empty:
+            all_tasks = tasks.list_tasks(size=1000)  # Get more tasks
+            # Filter by dataset name if available in the dataframe
+            if "data_name" in all_tasks.columns:
+                mask = all_tasks["data_name"].str.contains(args.query, case=False, na=False)
+                tasks_df = all_tasks[mask].head(args.size or 20)
+
+        if tasks_df.empty:
+            print(f"No tasks found for dataset matching '{args.query}'")
+        else:
+            print(f"Found {len(tasks_df)} task(s) for dataset matching '{args.query}':\n")
+            if args.format == "json":
+                print(tasks_df.to_json(orient="records", indent=2))
+            else:
+                print(tasks_df.to_string(index=False))
+    except Exception as e:  # noqa: BLE001
+        print(f"Error searching tasks: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def tasks_handler(args: argparse.Namespace) -> None:
+    """Route tasks subcommands to appropriate handlers."""
+    actions = {
+        "list": tasks_list,
+        "info": tasks_info,
+        "search": tasks_search,
+    }
+    action_func = actions.get(args.tasks_action)
+    if action_func:
+        action_func(args)
+    else:
+        print("Please specify a tasks action: list, info, or search")
+        sys.exit(1)
+
+
 def configure(args: argparse.Namespace) -> None:
     """Calls the right submenu(s) to edit `args.field` in the configuration file."""
     set_functions = {
@@ -328,7 +431,7 @@ def not_supported_yet(_: str) -> None:
 
 
 def main() -> None:
-    subroutines = {"configure": configure}
+    subroutines = {"configure": configure, "tasks": tasks_handler}
 
     parser = argparse.ArgumentParser()
     subparsers = parser.add_subparsers(dest="subroutine")
@@ -360,6 +463,64 @@ def main() -> None:
         help="The value to set the FIELD to.",
     )
 
+    # Tasks subparser (NEW)
+    parser_tasks = subparsers.add_parser(
+        "tasks",
+        description="Browse and search OpenML tasks from the command line.",
+    )
+    tasks_subparsers = parser_tasks.add_subparsers(dest="tasks_action")
+
+    # tasks list
+    parser_tasks_list = tasks_subparsers.add_parser(
+        "list",
+        help="List tasks with optional filtering",
+    )
+    parser_tasks_list.add_argument("--offset", type=int, help="Number of tasks to skip")
+    parser_tasks_list.add_argument("--size", type=int, help="Maximum number of tasks to show")
+    parser_tasks_list.add_argument("--tag", type=str, help="Filter by tag")
+    parser_tasks_list.add_argument("--task-type", type=str, help="Filter by task type")
+    parser_tasks_list.add_argument("--status", type=str, help="Filter by status")
+    parser_tasks_list.add_argument("--data-name", type=str, help="Filter by dataset name")
+    parser_tasks_list.add_argument(
+        "--format",
+        type=str,
+        choices=["table", "json"],
+        default="table",
+        help="Output format",
+    )
+    parser_tasks_list.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Show all columns",
+    )
+
+    # tasks info
+    parser_tasks_info = tasks_subparsers.add_parser(
+        "info",
+        help="Display detailed information about a specific task",
+    )
+    parser_tasks_info.add_argument("task_id", type=str, help="Task ID")
+
+    # tasks search
+    parser_tasks_search = tasks_subparsers.add_parser(
+        "search",
+        help="Search tasks by dataset name (case-insensitive)",
+    )
+    parser_tasks_search.add_argument("query", type=str, help="Dataset name to search for")
+    parser_tasks_search.add_argument("--size", type=int, help="Maximum number of results")
+    parser_tasks_search.add_argument(
+        "--format",
+        type=str,
+        choices=["table", "json"],
+        default="table",
+        help="Output format",
+    )
+    parser_tasks_search.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Show all columns",
+    )
+
     args = parser.parse_args()
 
     subroutines.get(args.subroutine, lambda _: parser.print_help())(args)
diff --git a/tests/test_openml/test_cli.py b/tests/test_openml/test_cli.py
new file mode 100644
index 000000000..544816291
--- /dev/null
+++ b/tests/test_openml/test_cli.py
@@ -0,0 +1,218 @@
+# License: BSD 3-Clause
+from __future__ import annotations
+
+import argparse
+from unittest import mock
+
+import pandas as pd
+import pytest
+
+from openml import cli
+
+
+class TestTasksCLI:
+    """Test suite for tasks CLI commands."""
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_list_basic(self, mock_list):
+        """Test basic task listing."""
+        mock_df = pd.DataFrame({
+            "tid": [1, 2, 3],
+            "task_type": ["Supervised Classification", "Supervised Regression", "Clustering"],
+            "did": [61, 62, 63],
+        })
+        mock_list.return_value = mock_df
+
+        args = argparse.Namespace(
+            offset=None,
+            size=10,
+            tag=None,
+            task_type=None,
+            status=None,
+            data_name=None,
+            format="table",
+            verbose=False,
+        )
+
+        cli.tasks_list(args)
+        mock_list.assert_called_once_with(offset=None, size=10)
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_list_with_filters(self, mock_list):
+        """Test task listing with filters."""
+        mock_df = pd.DataFrame({
+            "tid": [1],
+            "task_type": ["Supervised Classification"],
+        })
+        mock_list.return_value = mock_df
+
+        args = argparse.Namespace(
+            offset=0,
+            size=5,
+            tag="study_14",
+            task_type="Supervised Classification",
+            status="active",
+            data_name=None,
+            format="table",
+            verbose=False,
+        )
+
+        cli.tasks_list(args)
+        mock_list.assert_called_once_with(
+            offset=0,
+            size=5,
+            tag="study_14",
+            task_type="Supervised Classification",
+            status="active",
+        )
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_list_empty_results(self, mock_list):
+        """Test handling of empty results."""
+        mock_list.return_value = pd.DataFrame()
+
+        args = argparse.Namespace(
+            offset=None,
+            size=None,
+            tag=None,
+            task_type=None,
+            status=None,
+            data_name=None,
+            format="table",
+            verbose=False,
+        )
+
+        cli.tasks_list(args)
+        mock_list.assert_called_once()
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_list_json_format(self, mock_list):
+        """Test JSON output format."""
+        mock_df = pd.DataFrame({"tid": [1], "task_type": ["Supervised Classification"]})
+        mock_list.return_value = mock_df
+
+        args = argparse.Namespace(
+            offset=None,
+            size=10,
+            tag=None,
+            task_type=None,
+            status=None,
+            data_name=None,
+            format="json",
+            verbose=False,
+        )
+
+        cli.tasks_list(args)
+        mock_list.assert_called_once()
+
+    @mock.patch("openml.tasks.get_task")
+    def test_tasks_info(self, mock_get):
+        """Test task info display."""
+        mock_task = mock.Mock()
+        mock_task.task_id = 1
+        mock_task.task_type = "Supervised Classification"
+        mock_task.dataset_id = 61
+        mock_task.estimation_procedure = {"type": "crossvalidation"}
+        mock_task.evaluation_measure = "predictive_accuracy"
+        mock_task.target_name = "class"
+        mock_task.class_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
+        mock_get.return_value = mock_task
+
+        args = argparse.Namespace(task_id="1")
+
+        cli.tasks_info(args)
+        mock_get.assert_called_once_with(1)
+
+    @mock.patch("openml.tasks.get_task")
+    def test_tasks_info_error(self, mock_get):
+        """Test task info with invalid ID."""
+        mock_get.side_effect = Exception("Task not found")
+
+        args = argparse.Namespace(task_id="99999")
+
+        with pytest.raises(SystemExit):
+            cli.tasks_info(args)
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_search_found(self, mock_list):
+        """Test search with results."""
+        mock_df = pd.DataFrame({
+            "tid": [1],
+            "data_name": ["iris"],
+        })
+        mock_list.return_value = mock_df
+
+        args = argparse.Namespace(
+            query="iris",
+            size=20,
+            format="table",
+            verbose=False,
+        )
+
+        cli.tasks_search(args)
+        mock_list.assert_called()
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_search_not_found(self, mock_list):
+        """Test search with no results."""
+        mock_list.return_value = pd.DataFrame()
+
+        args = argparse.Namespace(
+            query="nonexistent",
+            size=20,
+            format="table",
+            verbose=False,
+        )
+
+        cli.tasks_search(args)
+        assert mock_list.call_count >= 1
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_search_case_insensitive(self, mock_list):
+        """Test case-insensitive search."""
+        # First call returns empty (no exact match)
+        # Second call returns all tasks for client-side filtering
+        mock_list.side_effect = [
+            pd.DataFrame(),  # No exact match
+            pd.DataFrame({
+                "tid": [1, 2],
+                "data_name": ["Iris", "IRIS-versicolor"],
+            }),
+        ]
+
+        args = argparse.Namespace(
+            query="iris",
+            size=20,
+            format="table",
+            verbose=False,
+        )
+
+        cli.tasks_search(args)
+        assert mock_list.call_count == 2
+
+    def test_tasks_handler_no_action(self):
+        """Test tasks handler with no action specified."""
+        args = argparse.Namespace(tasks_action=None)
+
+        with pytest.raises(SystemExit):
+            cli.tasks_handler(args)
+
+    @mock.patch("openml.tasks.list_tasks")
+    def test_tasks_handler_list_action(self, mock_list):
+        """Test tasks handler routes list action correctly."""
+        mock_list.return_value = pd.DataFrame({"tid": [1], "task_type": ["test"]})
+
+        args = argparse.Namespace(
+            tasks_action="list",
+            offset=None,
+            size=5,
+            tag=None,
+            task_type=None,
+            status=None,
+            data_name=None,
+            format="table",
+            verbose=False,
+        )
+
+        cli.tasks_handler(args)
+        mock_list.assert_called_once()
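
Reviewer note (not part of the patch): a minimal smoke test for the new
subcommands, assuming the diff above is applied to an openml-python checkout.
It drives cli.main() through sys.argv rather than assuming an installed
`openml` console script (this diff does not touch packaging), and task ID 31
is only an illustrative value; the calls hit the live OpenML server.

    import sys

    from openml import cli

    # Each argv mimics a shell invocation, e.g. `openml tasks info 31`;
    # argparse reads sys.argv[1:] inside cli.main().
    for argv in (
        ["openml", "tasks", "list", "--size", "5", "--status", "active"],
        ["openml", "tasks", "info", "31"],
        ["openml", "tasks", "search", "iris", "--format", "json"],
    ):
        sys.argv = argv
        cli.main()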