diff --git a/pytools/graph.py b/pytools/graph.py index 7b0f3423..27f50662 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -42,6 +42,7 @@ .. autoexception:: CycleError .. autofunction:: compute_topological_order .. autofunction:: compute_transitive_closure +.. autofunction:: find_cycles .. autofunction:: contains_cycle .. autofunction:: compute_induced_subgraph .. autofunction:: as_graphviz_dot @@ -68,6 +69,8 @@ Mapping, MutableSet, Optional, Set, Tuple, TypeVar) +from enum import Enum + try: from typing import TypeAlias except ImportError: @@ -242,6 +245,52 @@ def __init__(self, node: NodeT) -> None: self.node = node +class _NodeState(Enum): + WHITE = 0 # Not visited yet + GREY = 1 # Currently visiting + BLACK = 2 # Done visiting + + +def find_cycles(graph: GraphT, all_cycles: bool = True) -> List[List[NodeT]]: + """ + Find cycles in *graph* using DFS. + + :arg all_cycles: If False, only return the first cycle found. + + :returns: A :class:`list` in which each element represents another :class:`list` + of nodes that form a cycle. + """ + def dfs(node: NodeT, path: List[NodeT]) -> List[NodeT]: + # Cycle detected + if visited[node] == _NodeState.GREY: + return path + [node] + + # Visit this node, explore its children + visited[node] = _NodeState.GREY + for child in graph[node]: + if visited[child] != _NodeState.BLACK and dfs(child, path): + return path + [node] + ( + [child] if child != node else []) + + # Done visiting node + visited[node] = _NodeState.BLACK + return [] + + visited = {node: _NodeState.WHITE for node in graph.keys()} + + res = [] + + for node in graph: + if visited[node] == _NodeState.WHITE: + cycle = dfs(node, []) + if cycle: + res.append(cycle) + if not all_cycles: + return res + + return res + + class HeapEntry: """ Helper class to compare associated keys while comparing the elements in @@ -259,14 +308,17 @@ def __lt__(self, other: "HeapEntry") -> bool: def compute_topological_order(graph: GraphT[NodeT], - key: Optional[Callable[[NodeT], Any]] = None) \ - -> List[NodeT]: + key: Optional[Callable[[NodeT], Any]] = None, + verbose_cycle: bool = True) -> List[NodeT]: """Compute a topological order of nodes in a directed graph. :arg key: A custom key function may be supplied to determine the order in break-even cases. Expects a function of one argument that is used to extract a comparison key from each node of the *graph*. + :arg verbose_cycle: Verbose reporting in case *graph* contains a cycle, i.e. + return a :class:`CycleError` which has a node that is part of a cycle. + :returns: A :class:`list` representing a valid topological ordering of the nodes in the directed graph. @@ -318,9 +370,17 @@ def compute_topological_order(graph: GraphT[NodeT], heappush(heap, HeapEntry(child, keyfunc(child))) if len(order) != total_num_nodes: - # any node which has a predecessor left is a part of a cycle - raise CycleError(next(iter(n for n, num_preds in - nodes_to_num_predecessors.items() if num_preds != 0))) + # There is a cycle in the graph + if not verbose_cycle: + raise CycleError(None) + + try: + cycles: List[List[NodeT]] = find_cycles(graph) + except KeyError: + # Graph is invalid + raise CycleError(None) + else: + raise CycleError(cycles[0][0]) return order @@ -373,11 +433,7 @@ def contains_cycle(graph: GraphT[NodeT]) -> bool: .. versionadded:: 2020.2 """ - try: - compute_topological_order(graph) - return False - except CycleError: - return True + return bool(find_cycles(graph, all_cycles=False)) # }}} diff --git a/test/test_graph_tools.py b/test/test_graph_tools.py index 688c6e4c..e95564c6 100644 --- a/test/test_graph_tools.py +++ b/test/test_graph_tools.py @@ -431,6 +431,38 @@ def test_is_connected(): assert is_connected({}) +def test_find_cycles(): + from pytools.graph import compute_topological_order, CycleError, find_cycles + + # Non-Self Loop + graph = {1: {}, 5: {1, 8}, 8: {5}} + assert find_cycles(graph) == [[5, 8]] + with pytest.raises(CycleError, match="5|8"): + compute_topological_order(graph) + + # Self-Loop + graph = {1: {1}, 5: {8}, 8: {}} + assert find_cycles(graph) == [[1]] + with pytest.raises(CycleError, match="1"): + compute_topological_order(graph) + + # Invalid graph with loop + graph = {1: {42}, 5: {8}, 8: {5}} + # Can't run find_cycles on this graph since it is invalid + with pytest.raises(CycleError, match="None"): + compute_topological_order(graph) + + # Multiple loops + graph = {1: {1}, 5: {8}, 8: {5}} + assert find_cycles(graph) == [[1], [5, 8]] + with pytest.raises(CycleError, match="1"): + compute_topological_order(graph) + + # Cycle over multiple nodes + graph = {4: {2}, 2: {3}, 3: {4}} + assert find_cycles(graph) == [[4, 2, 3]] + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])