Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 163 additions & 2 deletions openml/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Callable
from urllib.parse import urlparse

from openml import config
from openml import config, tasks


def is_hex(string_: str) -> bool:
Expand Down Expand Up @@ -299,6 +299,109 @@ def configure_field( # noqa: PLR0913
verbose_set(field, value)


def tasks_list(args: argparse.Namespace) -> None:
"""List tasks with optional filtering."""
# Build filter parameters
kwargs = {}
if args.tag:
kwargs["tag"] = args.tag
if args.task_type:
kwargs["task_type"] = args.task_type
if args.status:
kwargs["status"] = args.status
if args.data_name:
kwargs["data_name"] = args.data_name

# Fetch tasks
try:
tasks_df = tasks.list_tasks(
offset=args.offset,
size=args.size,
**kwargs,
)
# Simple table output (reuse _format_output if it exists, otherwise simple print)
if tasks_df.empty:
print("No results found.")
elif args.format == "json":
print(tasks_df.to_json(orient="records", indent=2))
else:
print(tasks_df.to_string(index=False))
except Exception as e: # noqa: BLE001
print(f"Error listing tasks: {e}", file=sys.stderr)
sys.exit(1)


def tasks_info(args: argparse.Namespace) -> None:
"""Display detailed information about a specific task."""
try:
task = tasks.get_task(int(args.task_id))

# Print basic information
print(f"Task ID: {task.task_id}")
print(f"Task Type: {task.task_type}")
print(f"Dataset ID: {task.dataset_id}")
print(f"Estimation Procedure: {task.estimation_procedure.get('type', 'N/A')}")
if task.evaluation_measure:
print(f"Evaluation Measure: {task.evaluation_measure}")

# Print target feature if available
if hasattr(task, "target_name") and task.target_name:
print(f"Target Feature: {task.target_name}")

# Print class labels if available
if hasattr(task, "class_labels") and task.class_labels:
print(f"Number of Classes: {len(task.class_labels)}")
print(f"Class Labels: {', '.join(map(str, task.class_labels[:10]))}")
if len(task.class_labels) > 10:
print(f" ... and {len(task.class_labels) - 10} more")

except Exception as e: # noqa: BLE001
print(f"Error fetching task info: {e}", file=sys.stderr)
sys.exit(1)


def tasks_search(args: argparse.Namespace) -> None:
"""Search tasks by associated dataset name (case-insensitive)."""
try:
# Search by dataset name using the API
tasks_df = tasks.list_tasks(data_name=args.query, size=args.size or 100)

# If no exact match, do case-insensitive client-side filtering
if tasks_df.empty:
all_tasks = tasks.list_tasks(size=1000) # Get more tasks
# Filter by dataset name if available in the dataframe
if "data_name" in all_tasks.columns:
mask = all_tasks["data_name"].str.contains(args.query, case=False, na=False)
tasks_df = all_tasks[mask].head(args.size or 20)

if tasks_df.empty:
print(f"No tasks found for dataset matching '{args.query}'")
else:
print(f"Found {len(tasks_df)} task(s) for dataset matching '{args.query}':\n")
if args.format == "json":
print(tasks_df.to_json(orient="records", indent=2))
else:
print(tasks_df.to_string(index=False))
except Exception as e: # noqa: BLE001
print(f"Error searching tasks: {e}", file=sys.stderr)
sys.exit(1)


def tasks_handler(args: argparse.Namespace) -> None:
"""Route tasks subcommands to appropriate handlers."""
actions = {
"list": tasks_list,
"info": tasks_info,
"search": tasks_search,
}
action_func = actions.get(args.tasks_action)
if action_func:
action_func(args)
else:
print("Please specify a tasks action: list, info, or search")
sys.exit(1)


def configure(args: argparse.Namespace) -> None:
"""Calls the right submenu(s) to edit `args.field` in the configuration file."""
set_functions = {
Expand Down Expand Up @@ -328,7 +431,7 @@ def not_supported_yet(_: str) -> None:


def main() -> None:
subroutines = {"configure": configure}
subroutines = {"configure": configure, "tasks": tasks_handler}

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="subroutine")
Expand Down Expand Up @@ -360,6 +463,64 @@ def main() -> None:
help="The value to set the FIELD to.",
)

# Tasks subparser (NEW)
parser_tasks = subparsers.add_parser(
"tasks",
description="Browse and search OpenML tasks from the command line.",
)
tasks_subparsers = parser_tasks.add_subparsers(dest="tasks_action")

# tasks list
parser_tasks_list = tasks_subparsers.add_parser(
"list",
help="List tasks with optional filtering",
)
parser_tasks_list.add_argument("--offset", type=int, help="Number of tasks to skip")
parser_tasks_list.add_argument("--size", type=int, help="Maximum number of tasks to show")
parser_tasks_list.add_argument("--tag", type=str, help="Filter by tag")
parser_tasks_list.add_argument("--task-type", type=str, help="Filter by task type")
parser_tasks_list.add_argument("--status", type=str, help="Filter by status")
parser_tasks_list.add_argument("--data-name", type=str, help="Filter by dataset name")
parser_tasks_list.add_argument(
"--format",
type=str,
choices=["table", "json"],
default="table",
help="Output format",
)
parser_tasks_list.add_argument(
"--verbose",
action="store_true",
help="Show all columns",
)

# tasks info
parser_tasks_info = tasks_subparsers.add_parser(
"info",
help="Display detailed information about a specific task",
)
parser_tasks_info.add_argument("task_id", type=str, help="Task ID")

# tasks search
parser_tasks_search = tasks_subparsers.add_parser(
"search",
help="Search tasks by dataset name (case-insensitive)",
)
parser_tasks_search.add_argument("query", type=str, help="Dataset name to search for")
parser_tasks_search.add_argument("--size", type=int, help="Maximum number of results")
parser_tasks_search.add_argument(
"--format",
type=str,
choices=["table", "json"],
default="table",
help="Output format",
)
parser_tasks_search.add_argument(
"--verbose",
action="store_true",
help="Show all columns",
)

args = parser.parse_args()
subroutines.get(args.subroutine, lambda _: parser.print_help())(args)

Expand Down
Loading