diff --git a/openml/cli.py b/openml/cli.py
index d0a46e498..f9538a48d 100644
--- a/openml/cli.py
+++ b/openml/cli.py
@@ -6,10 +6,14 @@
 import string
 import sys
 from pathlib import Path
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 from urllib.parse import urlparse

 from openml import config
+from openml.runs import functions as run_functions
+
+if TYPE_CHECKING:
+    import pandas as pd


 def is_hex(string_: str) -> bool:
@@ -327,12 +331,268 @@ def not_supported_yet(_: str) -> None:
     set_field_function(args.value)


+def _format_runs_output(
+    runs_df: pd.DataFrame,
+    output_format: str,
+    *,
+    verbose: bool = False,
+) -> None:
+    """Format and print runs output based on requested format.
+
+    Parameters
+    ----------
+    runs_df : pd.DataFrame
+        DataFrame containing runs information
+    output_format : str
+        Output format: 'json', 'table', or 'list'
+    verbose : bool
+        Whether to show detailed information
+    """
+    if output_format == "json":
+        # Convert to JSON format
+        output = runs_df.to_json(orient="records", indent=2)
+        print(output)
+    elif output_format == "table":
+        _format_runs_table(runs_df, verbose=verbose)
+    else:  # default: simple list
+        _format_runs_list(runs_df, verbose=verbose)
+
+
+def _format_runs_table(runs_df: pd.DataFrame, *, verbose: bool = False) -> None:
+    """Format runs as a table.
+
+    Parameters
+    ----------
+    runs_df : pd.DataFrame
+        DataFrame containing runs information
+    verbose : bool
+        Whether to show all columns
+    """
+    if verbose:
+        print(runs_df.to_string(index=False))
+    else:
+        # Show only key columns for compact view
+        columns_to_show = ["run_id", "task_id", "flow_id", "uploader", "upload_time"]
+        available_columns = [col for col in columns_to_show if col in runs_df.columns]
+        print(runs_df[available_columns].to_string(index=False))
+
+
+def _format_runs_list(runs_df: pd.DataFrame, *, verbose: bool = False) -> None:
+    """Format runs as a simple list.
+
+    Parameters
+    ----------
+    runs_df : pd.DataFrame
+        DataFrame containing runs information
+    verbose : bool
+        Whether to show detailed information
+    """
+    if verbose:
+        # Verbose: show detailed info for each run
+        for _, run in runs_df.iterrows():
+            print(f"Run ID: {run['run_id']}")
+            print(f"  Task ID: {run['task_id']}")
+            print(f"  Flow ID: {run['flow_id']}")
+            print(f"  Setup ID: {run['setup_id']}")
+            print(f"  Uploader: {run['uploader']}")
+            print(f"  Upload Time: {run['upload_time']}")
+            if run.get("error_message"):
+                print(f"  Error: {run['error_message']}")
+            print()
+    else:
+        # Simple: just list run IDs
+        for run_id in runs_df["run_id"]:
+            print(f"{run_id}: Task {runs_df[runs_df['run_id'] == run_id]['task_id'].iloc[0]}")
+
+
+def runs_list(args: argparse.Namespace) -> None:
+    """List runs with optional filtering.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Arguments containing filtering criteria: task, flow, uploader, tag, size, offset, format
+    """
+    # Build filter arguments, excluding None values
+    kwargs = {}
+    if args.task is not None:
+        kwargs["task"] = [args.task]
+    if args.flow is not None:
+        kwargs["flow"] = [args.flow]
+    if args.uploader is not None:
+        kwargs["uploader"] = [args.uploader]
+    if args.tag is not None:
+        kwargs["tag"] = args.tag
+    if args.size is not None:
+        kwargs["size"] = args.size
+    if args.offset is not None:
+        kwargs["offset"] = args.offset
+
+    try:
+        # Get runs from server
+        runs_df = run_functions.list_runs(**kwargs)  # type: ignore[arg-type]
+
+        if runs_df.empty:
+            print("No runs found matching the criteria.")
+            return
+
+        # Format output based on requested format
+        _format_runs_output(runs_df, args.format, verbose=args.verbose)
+
+    except Exception as e:  # noqa: BLE001
+        print(f"Error listing runs: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def _print_run_evaluations(run: object) -> None:
+    """Print evaluation information for a run.
+
+    Parameters
+    ----------
+    run : OpenMLRun
+        The run object containing evaluation data
+    """
+    # Display evaluations if available
+    if hasattr(run, "evaluations") and run.evaluations:
+        print("\nEvaluations:")
+        for measure, value in run.evaluations.items():
+            print(f"  {measure}: {value}")
+
+    # Display fold evaluations if available (summary)
+    if hasattr(run, "fold_evaluations") and run.fold_evaluations:
+        print("\nFold Evaluations (Summary):")
+        for measure, repeats in run.fold_evaluations.items():
+            # Calculate average across all folds and repeats
+            all_values = []
+            for repeat_dict in repeats.values():
+                all_values.extend(repeat_dict.values())
+            if all_values:
+                avg_value = sum(all_values) / len(all_values)
+                print(f"  {measure}: {avg_value:.4f} (avg over {len(all_values)} folds)")
+
+
+def runs_info(args: argparse.Namespace) -> None:
+    """Display detailed information about a specific run.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Arguments containing the run_id to fetch
+    """
+    try:
+        # Get run from server
+        run = run_functions.get_run(args.run_id)
+
+        # Display run information
+        print(f"Run ID: {run.run_id}")
+        print(f"Task ID: {run.task_id}")
+        print(f"Task Type: {run.task_type}")
+        print(f"Flow ID: {run.flow_id}")
+        print(f"Flow Name: {run.flow_name}")
+        print(f"Setup ID: {run.setup_id}")
+        print(f"Dataset ID: {run.dataset_id}")
+        print(f"Uploader: {run.uploader_name} (ID: {run.uploader})")
+
+        # Display parameter settings if available
+        if run.parameter_settings:
+            print("\nParameter Settings:")
+            for param in run.parameter_settings:
+                component = param.get("oml:component", "")
+                name = param.get("oml:name", "")
+                value = param.get("oml:value", "")
+                if component:
+                    print(f"  {component}.{name}: {value}")
+                else:
+                    print(f"  {name}: {value}")
+
+        # Display evaluations
+        _print_run_evaluations(run)
+
+        # Display tags if available
+        if run.tags:
+            print(f"\nTags: {', '.join(run.tags)}")
+
+        # Display predictions URL if available
+        if run.predictions_url:
+            print(f"\nPredictions URL: {run.predictions_url}")
+
+        # Display output files if available
+        if run.output_files:
+            print("\nOutput Files:")
+            for file_name, file_id in run.output_files.items():
+                print(f"  {file_name}: {file_id}")
+
+    except Exception as e:  # noqa: BLE001
+        print(f"Error fetching run information: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def runs_download(args: argparse.Namespace) -> None:
+    """Download a run and cache it locally.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Arguments containing the run_id to download
+    """
+    try:
+        # Get run from server (this will download and cache it)
+        run = run_functions.get_run(args.run_id, ignore_cache=True)
+
+        print(f"Successfully downloaded run {run.run_id}")
+        print(f"Task ID: {run.task_id}")
+        print(f"Flow ID: {run.flow_id}")
+        print(f"Dataset ID: {run.dataset_id}")
+
+        # Display cache location
+        cache_dir = config.get_cache_directory()
+        run_cache_dir = Path(cache_dir) / "runs" / str(run.run_id)
+        if run_cache_dir.exists():
+            print(f"\nRun cached at: {run_cache_dir}")
+            # List cached files
+            cached_files = list(run_cache_dir.iterdir())
+            if cached_files:
+                print("Cached files:")
+                for file in cached_files:
+                    print(f"  - {file.name}")
+
+        if run.predictions_url:
+            print(f"\nPredictions available at: {run.predictions_url}")
+
+    except Exception as e:  # noqa: BLE001
+        print(f"Error downloading run: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def runs(args: argparse.Namespace) -> None:
+    """Route runs subcommands to the appropriate handler.
+
+    Parameters
+    ----------
+    args : argparse.Namespace
+        Arguments containing the subcommand and its arguments
+    """
+    subcommands = {
+        "list": runs_list,
+        "info": runs_info,
+        "download": runs_download,
+    }
+
+    handler = subcommands.get(args.runs_subcommand)
+    if handler:
+        handler(args)
+    else:
+        print(f"Unknown runs subcommand: {args.runs_subcommand}")
+        sys.exit(1)
+
+
 def main() -> None:
-    subroutines = {"configure": configure}
+    subroutines = {"configure": configure, "runs": runs}

     parser = argparse.ArgumentParser()
     subparsers = parser.add_subparsers(dest="subroutine")

+    # Configure subcommand
     parser_configure = subparsers.add_parser(
         "configure",
         description="Set or read variables in your configuration file. For more help also see "
@@ -360,6 +620,88 @@ def main() -> None:
         help="The value to set the FIELD to.",
     )

+    # Runs subcommand
+    parser_runs = subparsers.add_parser(
+        "runs",
+        description="Browse and search OpenML runs from the command line.",
+    )
+    runs_subparsers = parser_runs.add_subparsers(dest="runs_subcommand")
+
+    # runs list subcommand
+    parser_runs_list = runs_subparsers.add_parser(
+        "list",
+        description="List runs with optional filtering.",
+        help="List runs with optional filtering.",
+    )
+    parser_runs_list.add_argument(
+        "--task",
+        type=int,
+        help="Filter by task ID",
+    )
+    parser_runs_list.add_argument(
+        "--flow",
+        type=int,
+        help="Filter by flow ID",
+    )
+    parser_runs_list.add_argument(
+        "--uploader",
+        type=str,
+        help="Filter by uploader name or ID",
+    )
+    parser_runs_list.add_argument(
+        "--tag",
+        type=str,
+        help="Filter by tag",
+    )
+    parser_runs_list.add_argument(
+        "--size",
+        type=int,
+        default=10,
+        help="Number of runs to retrieve (default: 10)",
+    )
+    parser_runs_list.add_argument(
+        "--offset",
+        type=int,
+        default=0,
+        help="Offset for pagination (default: 0)",
+    )
+    parser_runs_list.add_argument(
+        "--format",
+        type=str,
+        choices=["list", "table", "json"],
+        default="list",
+        help="Output format (default: list)",
+    )
+    parser_runs_list.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Show detailed information",
+    )
+
+    # runs info subcommand
+    parser_runs_info = runs_subparsers.add_parser(
+        "info",
+        description="Display detailed information about a specific run.",
+        help="Display detailed information about a specific run.",
+    )
+    parser_runs_info.add_argument(
+        "run_id",
+        type=int,
+        help="Run ID to fetch information for",
+    )
+
+    # runs download subcommand
+    parser_runs_download = runs_subparsers.add_parser(
+        "download",
+        description="Download a run and cache it locally.",
+        help="Download a run and cache it locally.",
+    )
+    parser_runs_download.add_argument(
+        "run_id",
+        type=int,
+        help="Run ID to download",
+    )
+
     args = parser.parse_args()
     subroutines.get(args.subroutine, lambda _: parser.print_help())(args)

diff --git a/tests/test_openml/test_cli.py b/tests/test_openml/test_cli.py
new file mode 100644
index 000000000..87565b460
--- /dev/null
+++ b/tests/test_openml/test_cli.py
@@ -0,0 +1,491 @@
+# License: BSD 3-Clause
+"""Tests for the OpenML CLI commands."""
+from __future__ import annotations
+
+import argparse
+import sys
+from io import StringIO
+from unittest import mock
+
+import pandas as pd
+import pytest
+
+from openml import cli
+from openml.runs import OpenMLRun
+from openml.tasks import TaskType
+from openml.testing import TestBase
+
+
+class TestCLIRuns(TestBase):
+    """Test suite for openml runs CLI commands."""
+
+    def _create_mock_run(self, run_id: int, task_id: int, flow_id: int) -> OpenMLRun:
+        """Helper to create a mock OpenMLRun object."""
+        return OpenMLRun(
+            run_id=run_id,
+            task_id=task_id,
+            flow_id=flow_id,
+            dataset_id=1,
+            setup_id=100 + run_id,
+            uploader=1,
+            uploader_name="Test User",
+            flow_name=f"test.flow.{flow_id}",
+            task_type="Supervised Classification",
+            evaluations={"predictive_accuracy": 0.95, "area_under_roc_curve": 0.98},
+            fold_evaluations={
+                "predictive_accuracy": {0: {0: 0.94, 1: 0.96}},
+            },
+            parameter_settings=[
+                {"oml:name": "n_estimators", "oml:value": "100"},
+                {"oml:name": "max_depth", "oml:value": "10", "oml:component": "estimator"},
+            ],
+            tags=["test", "openml-python"],
+            predictions_url="https://test.openml.org/predictions/12345",
+            output_files={"predictions": 12345, "description": 12346},
+        )
+
+    def _create_mock_runs_dataframe(self) -> pd.DataFrame:
+        """Helper to create a mock DataFrame for list_runs."""
+        return pd.DataFrame(
+            {
+                "run_id": [1, 2, 3],
+                "task_id": [1, 1, 2],
+                "flow_id": [100, 101, 100],
+                "setup_id": [200, 201, 200],
+                "uploader": [1, 2, 1],
+                "task_type": [TaskType.SUPERVISED_CLASSIFICATION] * 3,
+                "upload_time": ["2024-01-01 10:00:00", "2024-01-02 11:00:00", "2024-01-03 12:00:00"],
+                "error_message": ["", "", ""],
+            }
+        )
+
+    @mock.patch("openml.runs.functions.list_runs")
+    def test_runs_list_simple(self, mock_list_runs: mock.Mock) -> None:
+        """Test runs list command with simple output."""
+        mock_list_runs.return_value = self._create_mock_runs_dataframe()
+
+        args = argparse.Namespace(
+            task=None,
+            flow=None,
+            uploader=None,
+            tag=None,
+            size=10,
+            offset=0,
+            format="list",
+            verbose=False,
+        )
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_list(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        assert "1: Task 1" in output
+        assert "2: Task 1" in output
+        assert "3: Task 2" in output
+        mock_list_runs.assert_called_once_with(size=10, offset=0)
+
+    @mock.patch("openml.runs.functions.list_runs")
+    def test_runs_list_with_filters(self, mock_list_runs: mock.Mock) -> None:
+        """Test runs list command with filtering parameters."""
+        mock_list_runs.return_value = self._create_mock_runs_dataframe()
+
+        args = argparse.Namespace(
+            task=1,
+            flow=100,
+            uploader="TestUser",
+            tag="test",
+            size=20,
+            offset=10,
+            format="list",
+            verbose=False,
+        )
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_list(args)
+
+        sys.stdout = sys.__stdout__
+
+        mock_list_runs.assert_called_once_with(
+            task=[1], flow=[100], uploader=["TestUser"], tag="test", size=20, offset=10
+        )
+
+    @mock.patch("openml.runs.functions.list_runs")
+    def test_runs_list_verbose(self, mock_list_runs: mock.Mock) -> None:
+        """Test runs list command with verbose output."""
+        mock_list_runs.return_value = self._create_mock_runs_dataframe()
+
+        args = argparse.Namespace(
+            task=None,
+            flow=None,
+            uploader=None,
+            tag=None,
+            size=10,
+            offset=0,
+            format="list",
+            verbose=True,
+        )
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_list(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        assert "Run ID: 1" in output
+        assert "Task ID: 1" in output
+        assert "Flow ID: 100" in output
+        assert "Setup ID: 200" in output
+        assert "Uploader: 1" in output
+
+    @mock.patch("openml.runs.functions.list_runs")
+    def test_runs_list_table_format(self, mock_list_runs: mock.Mock) -> None:
+        """Test runs list command with table format."""
+        mock_list_runs.return_value = self._create_mock_runs_dataframe()
+
+        args = argparse.Namespace(
+            task=None,
+            flow=None,
+            uploader=None,
+            tag=None,
+            size=10,
+            offset=0,
+            format="table",
+            verbose=False,
+        )
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_list(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        # Table format should show column headers
+        assert "run_id" in output
+        assert "task_id" in output
+        assert "flow_id" in output
+
+    @mock.patch("openml.runs.functions.list_runs")
+    def test_runs_list_json_format(self, mock_list_runs: mock.Mock) -> None:
+        """Test runs list command with JSON format."""
+        mock_list_runs.return_value = self._create_mock_runs_dataframe()
+
+        args = argparse.Namespace(
+            task=None,
+            flow=None,
+            uploader=None,
+            tag=None,
+            size=10,
+            offset=0,
+            format="json",
+            verbose=False,
+        )
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_list(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        # JSON format should contain valid JSON structure
+        assert '"run_id":' in output or '"run_id": ' in output
+        assert '"task_id":' in output or '"task_id": ' in output
+
+    @mock.patch("openml.runs.functions.list_runs")
+    def test_runs_list_empty_results(self, mock_list_runs: mock.Mock) -> None:
+        """Test runs list command with no results."""
+        mock_list_runs.return_value = pd.DataFrame()
+
+        args = argparse.Namespace(
+            task=999,
+            flow=None,
+            uploader=None,
+            tag=None,
+            size=10,
+            offset=0,
+            format="list",
+            verbose=False,
+        )
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_list(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        assert "No runs found" in output
+
+    @mock.patch("openml.runs.functions.list_runs")
+    def test_runs_list_error_handling(self, mock_list_runs: mock.Mock) -> None:
+        """Test runs list command error handling."""
+        mock_list_runs.side_effect = Exception("Connection error")
+
+        args = argparse.Namespace(
+            task=None,
+            flow=None,
+            uploader=None,
+            tag=None,
+            size=10,
+            offset=0,
+            format="list",
+            verbose=False,
+        )
+
+        # Capture stderr
+        captured_error = StringIO()
+        sys.stderr = captured_error
+
+        with pytest.raises(SystemExit):
+            cli.runs_list(args)
+
+        sys.stderr = sys.__stderr__
+
+        error = captured_error.getvalue()
+        assert "Error listing runs" in error
+        assert "Connection error" in error
+
+    @mock.patch("openml.runs.functions.get_run")
+    def test_runs_info(self, mock_get_run: mock.Mock) -> None:
+        """Test runs info command."""
+        mock_run = self._create_mock_run(run_id=12345, task_id=1, flow_id=100)
+        mock_get_run.return_value = mock_run
+
+        args = argparse.Namespace(run_id=12345)
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_info(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        assert "Run ID: 12345" in output
+        assert "Task ID: 1" in output
+        assert "Flow ID: 100" in output
+        assert "Flow Name: test.flow.100" in output
+        assert "Setup ID: 12445" in output  # 100 + 12345
+        assert "Dataset ID: 1" in output
+        assert "Uploader: Test User (ID: 1)" in output
+        assert "Parameter Settings:" in output
+        assert "n_estimators: 100" in output
+        assert "estimator.max_depth: 10" in output
+        assert "Evaluations:" in output
+        assert "predictive_accuracy: 0.95" in output
+        assert "area_under_roc_curve: 0.98" in output
+        assert "Tags: test, openml-python" in output
+        assert "Predictions URL: https://test.openml.org/predictions/12345" in output
+
+        mock_get_run.assert_called_once_with(12345)
+
+    @mock.patch("openml.runs.functions.get_run")
+    def test_runs_info_with_fold_evaluations(self, mock_get_run: mock.Mock) -> None:
+        """Test runs info command displays fold evaluation summary."""
+        mock_run = self._create_mock_run(run_id=12345, task_id=1, flow_id=100)
+        mock_get_run.return_value = mock_run
+
+        args = argparse.Namespace(run_id=12345)
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_info(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        assert "Fold Evaluations (Summary):" in output
+        # Average of 0.94 and 0.96 = 0.95
+        assert "0.9500" in output or "0.95" in output
+
+    @mock.patch("openml.runs.functions.get_run")
+    def test_runs_info_error_handling(self, mock_get_run: mock.Mock) -> None:
+        """Test runs info command error handling."""
+        mock_get_run.side_effect = Exception("Run not found")
+
+        args = argparse.Namespace(run_id=99999)
+
+        # Capture stderr
+        captured_error = StringIO()
+        sys.stderr = captured_error
+
+        with pytest.raises(SystemExit):
+            cli.runs_info(args)
+
+        sys.stderr = sys.__stderr__
+
+        error = captured_error.getvalue()
+        assert "Error fetching run information" in error
+        assert "Run not found" in error
+
+    @mock.patch("openml.config.get_cache_directory")
+    @mock.patch("openml.runs.functions.get_run")
+    def test_runs_download(self, mock_get_run: mock.Mock, mock_get_cache: mock.Mock) -> None:
+        """Test runs download command."""
+        mock_run = self._create_mock_run(run_id=12345, task_id=1, flow_id=100)
+        mock_get_run.return_value = mock_run
+        mock_get_cache.return_value = "/tmp/openml_cache"
+
+        args = argparse.Namespace(run_id=12345)
+
+        # Capture stdout
+        captured_output = StringIO()
+        sys.stdout = captured_output
+
+        cli.runs_download(args)
+
+        sys.stdout = sys.__stdout__
+
+        output = captured_output.getvalue()
+        assert "Successfully downloaded run 12345" in output
+        assert "Task ID: 1" in output
+        assert "Flow ID: 100" in output
+        assert "Dataset ID: 1" in output
+        assert "Predictions available at: https://test.openml.org/predictions/12345" in output
+
+        mock_get_run.assert_called_once_with(12345, ignore_cache=True)
+
+    @mock.patch("openml.runs.functions.get_run")
+    def test_runs_download_error_handling(self, mock_get_run: mock.Mock) -> None:
+        """Test runs download command error handling."""
+        mock_get_run.side_effect = Exception("Download failed")
+
+        args = argparse.Namespace(run_id=12345)
+
+        # Capture stderr
+        captured_error = StringIO()
+        sys.stderr = captured_error
+
+        with pytest.raises(SystemExit):
+            cli.runs_download(args)
+
+        sys.stderr = sys.__stderr__
+
+        error = captured_error.getvalue()
+        assert "Error downloading run" in error
+        assert "Download failed" in error
+
+    def test_runs_dispatcher(self) -> None:
+        """Test runs command dispatcher."""
+        # Test with list subcommand
+        with mock.patch("openml.cli.runs_list") as mock_list:
+            args = argparse.Namespace(runs_subcommand="list")
+            cli.runs(args)
+            mock_list.assert_called_once_with(args)
+
+        # Test with info subcommand
+        with mock.patch("openml.cli.runs_info") as mock_info:
+            args = argparse.Namespace(runs_subcommand="info")
+            cli.runs(args)
+            mock_info.assert_called_once_with(args)
+
+        # Test with download subcommand
+        with mock.patch("openml.cli.runs_download") as mock_download:
+            args = argparse.Namespace(runs_subcommand="download")
+            cli.runs(args)
+            mock_download.assert_called_once_with(args)
+
+    def test_runs_dispatcher_invalid_subcommand(self) -> None:
+        """Test runs command dispatcher with invalid subcommand."""
+        args = argparse.Namespace(runs_subcommand="invalid")
+
+        # Capture stderr
+        captured_error = StringIO()
+        sys.stderr = captured_error
+
+        with pytest.raises(SystemExit):
+            cli.runs(args)
+
+        sys.stderr = sys.__stderr__
+
+
+class TestCLIIntegration(TestBase):
+    """Integration tests for CLI argument parsing."""
+
+    def test_main_help(self) -> None:
+        """Test that the top-level parser accepts the runs subcommand."""
+        parser = argparse.ArgumentParser()
+        subparsers = parser.add_subparsers(dest="subroutine")
+        subparsers.add_parser("runs")
+
+        # Should not raise an error
+        args = parser.parse_args(["runs"])
+        assert args.subroutine == "runs"
+
+    def test_runs_list_argument_parsing(self) -> None:
+        """Test argument parsing for runs list command."""
+        # Create a minimal parser for testing
+        parser = argparse.ArgumentParser()
+        subparsers = parser.add_subparsers(dest="subroutine")
+        runs_parser = subparsers.add_parser("runs")
+        runs_subparsers = runs_parser.add_subparsers(dest="runs_subcommand")
+        list_parser = runs_subparsers.add_parser("list")
+
+        list_parser.add_argument("--task", type=int)
+        list_parser.add_argument("--flow", type=int)
+        list_parser.add_argument("--uploader", type=str)
+        list_parser.add_argument("--tag", type=str)
+        list_parser.add_argument("--size", type=int, default=10)
+        list_parser.add_argument("--offset", type=int, default=0)
+        list_parser.add_argument("--format", choices=["list", "table", "json"], default="list")
+        list_parser.add_argument("--verbose", action="store_true")
+
+        # Test with various arguments
+        args = parser.parse_args(["runs", "list", "--task", "1", "--flow", "100", "--size", "20"])
+        assert args.subroutine == "runs"
+        assert args.runs_subcommand == "list"
+        assert args.task == 1
+        assert args.flow == 100
+        assert args.size == 20
+        assert args.offset == 0
+        assert args.format == "list"
+        assert args.verbose is False
+
+    def test_runs_info_argument_parsing(self) -> None:
+        """Test argument parsing for runs info command."""
+        parser = argparse.ArgumentParser()
+        subparsers = parser.add_subparsers(dest="subroutine")
+        runs_parser = subparsers.add_parser("runs")
+        runs_subparsers = runs_parser.add_subparsers(dest="runs_subcommand")
+        info_parser = runs_subparsers.add_parser("info")
+        info_parser.add_argument("run_id", type=int)
+
+        args = parser.parse_args(["runs", "info", "12345"])
+        assert args.subroutine == "runs"
+        assert args.runs_subcommand == "info"
+        assert args.run_id == 12345
+
+    def test_runs_download_argument_parsing(self) -> None:
+        """Test argument parsing for runs download command."""
+        parser = argparse.ArgumentParser()
+        subparsers = parser.add_subparsers(dest="subroutine")
+        runs_parser = subparsers.add_parser("runs")
+        runs_subparsers = runs_parser.add_subparsers(dest="runs_subcommand")
+        download_parser = runs_subparsers.add_parser("download")
+        download_parser.add_argument("run_id", type=int)
+
+        args = parser.parse_args(["runs", "download", "12345"])
+        assert args.subroutine == "runs"
+        assert args.runs_subcommand == "download"
+        assert args.run_id == 12345
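
Usage sketch (illustrative only, not part of the patch): with the patch applied, the new handlers can be driven directly from Python by building the same argparse.Namespace the parser would produce, mirroring the tests above. The filter values and run ID below are placeholders, and a reachable, configured OpenML server is assumed.

    import argparse

    from openml import cli

    # List ten runs for a placeholder task ID, rendered as a table.
    cli.runs_list(
        argparse.Namespace(
            task=31, flow=None, uploader=None, tag=None,
            size=10, offset=0, format="table", verbose=False,
        )
    )

    # Show details for a single run (placeholder run ID).
    cli.runs_info(argparse.Namespace(run_id=1))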