diff --git a/openml/flows/__init__.py b/openml/flows/__init__.py index ce32fec7d..5e14f9511 100644 --- a/openml/flows/__init__.py +++ b/openml/flows/__init__.py @@ -4,6 +4,7 @@ from .functions import ( assert_flows_equal, delete_flow, + edit_flow, flow_exists, get_flow, get_flow_id, @@ -18,4 +19,5 @@ "flow_exists", "assert_flows_equal", "delete_flow", + "edit_flow", ] diff --git a/openml/flows/functions.py b/openml/flows/functions.py index 9906958e5..6e6a4cb6b 100644 --- a/openml/flows/functions.py +++ b/openml/flows/functions.py @@ -552,3 +552,113 @@ def delete_flow(flow_id: int) -> bool: True if the deletion was successful. False otherwise. """ return openml.utils._delete_entity("flow", flow_id) + + +def edit_flow( + flow_id: int, + custom_name: str | None = None, + tags: list[str] | None = None, + language: str | None = None, + description: str | None = None, +) -> int: + """Edits an OpenMLFlow. + + In addition to providing the flow id of the flow to edit (through flow_id), + you must specify a value for at least one of the optional function arguments, + i.e. one value for a field to edit. + + This function allows editing of non-critical fields only. + Editable fields are: custom_name, tags, language, description. + + Editing is allowed only for the owner of the flow. + + Parameters + ---------- + flow_id : int + ID of the flow. + custom_name : str, optional + Custom name for the flow. + tags : list[str], optional + Tags to associate with the flow. + language : str, optional + Language in which the flow is described. + Starts with 1 upper case letter, rest lower case, e.g. 'English'. + description : str, optional + Human-readable description of the flow. + + Returns + ------- + flow_id : int + The ID of the edited flow. + + Raises + ------ + TypeError + If flow_id is not an integer. + ValueError + If no fields are provided for editing. + OpenMLServerException + If the user is not authorized to edit the flow or if the flow doesn't exist. + + Examples + -------- + >>> import openml + >>> # Edit the custom name of a flow + >>> edited_flow_id = openml.flows.edit_flow(123, custom_name="My Custom Flow Name") + >>> + >>> # Edit multiple fields at once + >>> edited_flow_id = openml.flows.edit_flow( + ... 456, + ... custom_name="Updated Flow", + ... language="English", + ... description="An updated description for this flow", + ... tags=["machine-learning", "classification"] + ... ) + """ + if not isinstance(flow_id, int): + raise TypeError(f"`flow_id` must be of type `int`, not {type(flow_id)}.") + + # Check if at least one field is provided for editing + fields_to_edit = [custom_name, tags, language, description] + if all(field is None for field in fields_to_edit): + raise ValueError( + "At least one field must be provided for editing. " + "Available fields: custom_name, tags, language, description" + ) + + # Compose flow edit parameters as XML + form_data = {"flow_id": flow_id} # type: openml._api_calls.DATA_TYPE + xml = OrderedDict() # type: 'OrderedDict[str, OrderedDict]' + xml["oml:flow_edit_parameters"] = OrderedDict() + xml["oml:flow_edit_parameters"]["@xmlns:oml"] = "http://openml.org/openml" + xml["oml:flow_edit_parameters"]["oml:custom_name"] = custom_name + xml["oml:flow_edit_parameters"]["oml:language"] = language + xml["oml:flow_edit_parameters"]["oml:description"] = description + + # Handle tags - convert list to comma-separated string if provided + if tags is not None: + if isinstance(tags, list): + xml["oml:flow_edit_parameters"]["oml:tag"] = ",".join(tags) + else: + xml["oml:flow_edit_parameters"]["oml:tag"] = str(tags) + else: + xml["oml:flow_edit_parameters"]["oml:tag"] = None + + # Remove None values from XML + for key in list(xml["oml:flow_edit_parameters"]): + if not xml["oml:flow_edit_parameters"][key]: + del xml["oml:flow_edit_parameters"][key] + + file_elements = { + "edit_parameters": ("description.xml", xmltodict.unparse(xml)), + } # type: openml._api_calls.FILE_ELEMENTS_TYPE + + result_xml = openml._api_calls._perform_api_call( + "flow/edit", + "post", + data=form_data, + file_elements=file_elements, + ) + result = xmltodict.parse(result_xml) + edited_flow_id = result["oml:flow_edit"]["oml:id"] + return int(edited_flow_id) diff --git a/tests/test_flows/test_flow_functions.py b/tests/test_flows/test_flow_functions.py index ef4759e54..7b1d3b669 100644 --- a/tests/test_flows/test_flow_functions.py +++ b/tests/test_flows/test_flow_functions.py @@ -537,3 +537,87 @@ def test_delete_unknown_flow(mock_delete, test_files_directory, test_api_key): flow_url = "https://test.openml.org/api/v1/xml/flow/9999999" assert flow_url == mock_delete.call_args.args[0] assert test_api_key == mock_delete.call_args.kwargs.get("params", {}).get("api_key") + + +@mock.patch.object(openml._api_calls, "_perform_api_call") +def test_edit_flow_custom_name(mock_api_call): + """Test edit_flow with custom_name field.""" + # Mock the API response + mock_api_call.return_value = '123' + + result = openml.flows.edit_flow(123, custom_name="New Custom Name") + + # Check that the function returns the correct flow ID + assert result == 123 + + # Verify the API call was made with correct parameters + mock_api_call.assert_called_once() + call_args = mock_api_call.call_args + assert call_args[0][0] == "flow/edit" # endpoint + assert call_args[0][1] == "post" # method + assert call_args[1]["data"]["flow_id"] == 123 + + +@mock.patch.object(openml._api_calls, "_perform_api_call") +def test_edit_flow_multiple_fields(mock_api_call): + """Test edit_flow with multiple fields.""" + # Mock the API response + mock_api_call.return_value = '456' + + result = openml.flows.edit_flow( + 456, + custom_name="Updated Name", + language="English", + description="Updated description", + tags=["tag1", "tag2"] + ) + + # Check that the function returns the correct flow ID + assert result == 456 + + # Verify the API call was made + mock_api_call.assert_called_once() + call_args = mock_api_call.call_args + assert call_args[0][0] == "flow/edit" + assert call_args[0][1] == "post" + assert call_args[1]["data"]["flow_id"] == 456 + + +@mock.patch.object(openml._api_calls, "_perform_api_call") +def test_edit_flow_tags_as_list(mock_api_call): + """Test edit_flow with tags provided as a list.""" + # Mock the API response + mock_api_call.return_value = '789' + + result = openml.flows.edit_flow(789, tags=["machine-learning", "sklearn"]) + + # Check that the function returns the correct flow ID + assert result == 789 + + # Verify the API call was made + mock_api_call.assert_called_once() + + +@mock.patch.object(openml._api_calls, "_perform_api_call") +def test_edit_flow_server_error(mock_api_call): + """Test edit_flow when server returns an error.""" + from openml.exceptions import OpenMLServerException + + # Mock a server error + mock_api_call.side_effect = OpenMLServerException("Flow not found") + + with pytest.raises(OpenMLServerException, match="Flow not found"): + openml.flows.edit_flow(999, custom_name="Test") + + def test_edit_flow_invalid_flow_id(self): + """Test that edit_flow raises TypeError for non-integer flow_id.""" + with pytest.raises(TypeError, match="`flow_id` must be of type `int`"): + openml.flows.edit_flow("not_an_int", custom_name="test") + + def test_edit_flow_no_fields(self): + """Test that edit_flow raises ValueError when no fields are provided.""" + with pytest.raises( + ValueError, + match="At least one field must be provided for editing" + ): + openml.flows.edit_flow(1)