diff --git a/lib/compiler/include/compiler/algorithm_config.variant.toml b/lib/compiler/include/compiler/algorithm_config.variant.toml new file mode 100644 index 0000000000..4e58104875 --- /dev/null +++ b/lib/compiler/include/compiler/algorithm_config.variant.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "AlgorithmConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/data_parallelism/data_parallelism_config.dtg.h", + "compiler/unity_algorithm/unity_search_config.dtg.h", +] + +[[values]] +type = "::FlexFlow::DataParallelismConfig" + +[[values]] +type = "::FlexFlow::UnitySearchConfig" diff --git a/lib/compiler/include/compiler/compiler.h b/lib/compiler/include/compiler/compiler.h index 178ab19a53..8697c06beb 100644 --- a/lib/compiler/include/compiler/compiler.h +++ b/lib/compiler/include/compiler/compiler.h @@ -1,42 +1,22 @@ #ifndef _FLEXFLOW_COMPILER_COMPILER_H #define _FLEXFLOW_COMPILER_COMPILER_H -#include "pcg/cost_values.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/tensor_mapping.h" +#include "compiler/algorithm_config.dtg.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/search_result.dtg.h" +#include "pcg/machine_specification.dtg.h" namespace FlexFlow { enum class SearchAlgorithm { DATA_PARALLEL, -}; - -using SearchAlgorithmConfig = std::variant<>; -using SearchSolution = std::variant<>; - -struct SearchResult { - ParallelComputationGraph pcg; - TensorMapping tensor_mapping; - SearchSolution solution; - CostValues cost_values; + UNITY, }; SearchResult optimize(ComputationGraph const &, MachineSpecification const &, CostEstimator const &, - SearchAlgorithm, - optional const &); - -// struct SearchSolution { -// LabelledMultiDiGraph optimized_pcg; -// std::unordered_map device_assignments; -// /* std::unordered_map> tensor_mappings; */ -// }; -// -// SearchSolution run_data_parallelize(ComputationGraph const &, -// MachineSpecification const &); + AlgorithmConfig const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/data_parallelism/data_parallelism_config.struct.toml b/lib/compiler/include/compiler/data_parallelism/data_parallelism_config.struct.toml new file mode 100644 index 0000000000..68512fa473 --- /dev/null +++ b/lib/compiler/include/compiler/data_parallelism/data_parallelism_config.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "DataParallelismConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ +] + +[[fields]] +name = "degree" +type = "int" diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml deleted file mode 100644 index 22f29cbd59..0000000000 --- a/lib/compiler/include/compiler/graph_optimize_result.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "GraphOptimizeResult" -features = [ ] - -includes = [ - "compiler/machine_mapping/machine_mapping.dtg.h", - "pcg/parallel_computation_graph/parallel_computation_graph.h" -] - -[[fields]] -name = "pcg" -type = "::FlexFlow::ParallelComputationGraph" - -[[fields]] -name = "machine_mapping" -type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/allowed_machine_views.h b/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h similarity index 100% rename from lib/compiler/include/compiler/allowed_machine_views.h rename to lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h index 7375cde985..796225637e 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -2,6 +2,8 @@ #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H #include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_specification.dtg.h" #include "pcg/operator_task_space.dtg.h" @@ -14,6 +16,13 @@ MachineMapping combine_disjoint_mappings(MachineMapping const &, bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); +parallel_layer_guid_t + get_layer_from_path(PCGBinarySPDecomposition const &sp_decomposition, + BinaryTreePath const &path); + +std::optional get_machine_mapping_from_machine_mapping_result( + PCGBinarySPDecomposition const &, MachineMappingResult const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h index 68d02aaa54..168ba6c3d5 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -9,6 +9,9 @@ namespace FlexFlow { +bool is_valid_machine_mapping_problem_tree( + MachineMappingProblemTree const &problem_tree); + MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, PCGBinarySPDecomposition const &sp); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 29e9e7c90b..3d1dc91d24 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -4,6 +4,7 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" @@ -27,6 +28,9 @@ std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, BinaryTreePath const &); +std::string as_dot(MachineMappingProblemTree const &); +void debug_print_dot(MachineMappingProblemTree const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml index fe76683eb7..7493c68387 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml @@ -11,6 +11,7 @@ includes = [ "op-attrs/parallel_tensor_shape.dtg.h", "", "pcg/machine_view.dtg.h", + "pcg/operator_task_space.dtg.h", ] src_includes = [ @@ -34,3 +35,6 @@ type = "std::vector<::FlexFlow::ParallelTensorShape>" name = "output_shapes" type = "std::vector<::FlexFlow::ParallelTensorShape>" +[[fields]] +name = "op_task_space" +type = "::FlexFlow::OperatorTaskSpace" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index b21fea5f24..db2f4e6f0d 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -31,6 +31,8 @@ FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); make_singleton_machine_mapping_result(float runtime, MachineView const &machine_view); +[[nodiscard]] float get_runtime_cost(MachineMappingResult const &mm_result); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/graph_optimize_result.h b/lib/compiler/include/compiler/search_result.h similarity index 54% rename from lib/compiler/include/compiler/graph_optimize_result.h rename to lib/compiler/include/compiler/search_result.h index f3843e2a93..197b36e9ea 100644 --- a/lib/compiler/include/compiler/graph_optimize_result.h +++ b/lib/compiler/include/compiler/search_result.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H -#include "compiler/graph_optimize_result.dtg.h" +#include "compiler/search_result.dtg.h" namespace FlexFlow { -std::string format_as(GraphOptimizeResult const &); -std::ostream &operator<<(std::ostream &, GraphOptimizeResult const &); +std::string format_as(SearchResult const &); +std::ostream &operator<<(std::ostream &, SearchResult const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/search_result.struct.toml b/lib/compiler/include/compiler/search_result.struct.toml new file mode 100644 index 0000000000..120d182c75 --- /dev/null +++ b/lib/compiler/include/compiler/search_result.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SearchResult" +features = [ +] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.h", + "compiler/machine_mapping/machine_mapping.h", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h index d43edaa79d..bb7459c767 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" + namespace FlexFlow { std::optional diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h index 86fa1a59aa..e4fd841787 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -27,6 +27,10 @@ std::optional std::unordered_multiset get_parallel_layers(PCGBinarySPDecomposition const &); +PCGBinarySPDecomposition + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + BinarySPDecompositionTree const &); + SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); std::unordered_set diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h deleted file mode 100644 index 232f2b9563..0000000000 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H -#define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H - -#include "compiler/cost_estimator/cost_estimator.h" -#include "compiler/graph_optimize_result.dtg.h" -#include "optimizer_config.dtg.h" -#include "pcg/computation_graph.h" -#include "pcg/machine_specification.dtg.h" -#include "substitutions/sub_parallel_computation_graph.h" - -namespace FlexFlow { - -GraphOptimizeResult graph_optimize( - ParallelComputationGraph &pcg, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimizerConfig const &opt_config); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h similarity index 74% rename from lib/compiler/include/compiler/graph_optimize_state.h rename to lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h index 404111ff8b..5f06fd242c 100644 --- a/lib/compiler/include/compiler/graph_optimize_state.h +++ b/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h @@ -1,16 +1,17 @@ #ifndef _FLEXFLOW_COMPILER_MCMC_STATE_H #define _FLEXFLOW_COMPILER_MCMC_STATE_H -#include "compiler/graph_optimize_result.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" namespace FlexFlow { struct GraphOptimizeState { - explicit GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, + GraphOptimizeState() = delete; + explicit GraphOptimizeState(ParallelComputationGraph const &pcg, float runtime); - GraphOptimizeResult graph_optimize_result; - float runtime; + ParallelComputationGraph pcg; + float runtime_with_optimal_mm; bool operator==(GraphOptimizeState const &other) const; bool operator!=(GraphOptimizeState const &other) const; diff --git a/lib/compiler/include/compiler/unity_algorithm/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm/unity_algorithm.h new file mode 100644 index 0000000000..618e764f80 --- /dev/null +++ b/lib/compiler/include/compiler/unity_algorithm/unity_algorithm.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H +#define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H + +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/search_result.dtg.h" +#include "compiler/unity_algorithm/unity_search_config.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "substitutions/substitution.h" + +namespace FlexFlow { + +SearchResult graph_optimize(ParallelComputationGraph &pcg, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + UnitySearchConfig const &search_config); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/optimizer_config.struct.toml b/lib/compiler/include/compiler/unity_algorithm/unity_search_config.struct.toml similarity index 90% rename from lib/compiler/include/compiler/optimizer_config.struct.toml rename to lib/compiler/include/compiler/unity_algorithm/unity_search_config.struct.toml index b7f4f71e9c..9ec22cf916 100644 --- a/lib/compiler/include/compiler/optimizer_config.struct.toml +++ b/lib/compiler/include/compiler/unity_algorithm/unity_search_config.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "OptimizerConfig" +name = "UnitySearchConfig" features = [ "eq", "hash", diff --git a/lib/compiler/src/compiler/compiler.cc b/lib/compiler/src/compiler/compiler.cc new file mode 100644 index 0000000000..a58651f01a --- /dev/null +++ b/lib/compiler/src/compiler/compiler.cc @@ -0,0 +1,26 @@ +#include "compiler/compiler.h" +#include "compiler/unity_algorithm/unity_algorithm.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/overload.h" + +namespace FlexFlow { + +SearchResult optimize(ComputationGraph const &computation_graph, + MachineSpecification const &machine_specification, + CostEstimator const &cost_estimator, + AlgorithmConfig const &search_config) { + return search_config.visit(overload{ + [&](DataParallelismConfig const &config) -> SearchResult { + throw std::runtime_error( + "Data parallel search algorithm is not implemented yet"); + }, + [&](UnitySearchConfig const &config) { + ParallelComputationGraph pcg = + pcg_from_computation_graph(computation_graph); + return graph_optimize( + pcg, cost_estimator, machine_specification, config); + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_state.cc b/lib/compiler/src/compiler/graph_optimize_state.cc deleted file mode 100644 index 1091b92866..0000000000 --- a/lib/compiler/src/compiler/graph_optimize_state.cc +++ /dev/null @@ -1,96 +0,0 @@ -#include "compiler/graph_optimize_state.h" -#include "compiler/graph_optimize_result.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" - -namespace FlexFlow { - -GraphOptimizeState::GraphOptimizeState( - GraphOptimizeResult const &graph_optimize_result, float runtime) - : graph_optimize_result(graph_optimize_result), runtime(runtime) {} - -bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { - // Note(@wmdi): This is a hack to implement a partially correct homomorphism - // check. Switch to the homomorphism check used in substitutions right after - // https://github.com/flexflow/FlexFlow/pull/1471 is merged. - auto layers1 = topological_ordering(graph_optimize_result.pcg); - auto layers2 = topological_ordering(other.graph_optimize_result.pcg); - if (layers1.size() != layers2.size()) { - return false; - } - std::unordered_map mapping; - for (size_t i = 0; i < layers1.size(); ++i) { - if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != - get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { - return false; - } - auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); - auto inputs2 = - get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); - if (inputs1.size() != inputs2.size()) { - return false; - } - for (size_t j = 0; j < inputs1.size(); ++j) { - if (inputs1[j] != mapping.at(inputs2[j])) { - return false; - } - } - auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); - auto outputs2 = - get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); - if (outputs1.size() != outputs2.size()) { - return false; - } - for (size_t j = 0; j < outputs1.size(); ++j) { - mapping.emplace(outputs2[j], outputs1[j]); - } - } - return true; -} - -bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { - return !(*this == other); -} - -bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { - return runtime < other.runtime; -} - -std::string format_as(GraphOptimizeState const &st) { - return fmt::format("", - st.graph_optimize_result, - st.runtime); -} - -std::ostream &operator<<(std::ostream &s, GraphOptimizeState const &st) { - return (s << fmt::to_string(st)); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::GraphOptimizeState>::operator()( - ::FlexFlow::GraphOptimizeState const &state) const { - // TODO(@wmdi): Eventually it might be good to use a proper graph hash like - // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash - size_t seed = 0; - auto layers = topological_ordering(state.graph_optimize_result.pcg); - ::FlexFlow::hash_combine(seed, layers.size()); - for (auto layer : layers) { - ::FlexFlow::hash_combine( - seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); - auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); - ::FlexFlow::hash_combine(seed, inputs.size()); - for (auto input : inputs) { - for (size_t i = 0; i < layers.size(); ++i) { - if (get_source_layer(input) == layers[i]) { - ::FlexFlow::hash_combine(seed, i); - break; - } - } - } - } - return seed; -} - -} // namespace std diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc similarity index 79% rename from lib/compiler/src/compiler/allowed_machine_views.cc rename to lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc index 6f86d1d82a..b4df1451ca 100644 --- a/lib/compiler/src/compiler/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc @@ -1,4 +1,4 @@ -#include "compiler/allowed_machine_views.h" +#include "compiler/machine_mapping/allowed_machine_views.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/multi_dimensional_stride.dtg.h" @@ -57,6 +57,8 @@ static std::unordered_set product(transform(tensor_dims, [](nonnegative_int num_devices) { return nonnegative_int{num_devices.unwrap_nonnegative() - 1}; })); + min_num_devices_with_full_stride_volume = + std::max(min_num_devices_with_full_stride_volume, 1_n); return ceildiv(total_devices, min_num_devices_with_full_stride_volume); }; @@ -66,13 +68,19 @@ static std::unordered_set nonnegative_int max_stride_upper_bound = get_max_stride_upper_bound(tensor_dims, total_devices); - std::vector single_stride_range = - transform(nonnegative_range(1_n, max_stride_upper_bound + 1_n), - [](nonnegative_int stride) { return stride_t{stride}; }); + std::vector> stride_options = + transform(tensor_dims, [&](nonnegative_int dim_size) { + if (dim_size != 1_n) { + return transform( + nonnegative_range(1_n, max_stride_upper_bound + 1_n), + [](nonnegative_int stride) { return stride_t{stride}; }); + } else { + return std::vector{stride_t{1_n}}; + } + }); + std::unordered_multiset> raw_stride_vectors = - cartesian_product( - repeat_element(/*num_times=*/num_elements(tensor_dims), - /*element=*/single_stride_range)); + cartesian_product(stride_options); std::unordered_multiset strides = transform(raw_stride_vectors, [](auto const &stride_vec) { return MultiDimensionalStride{stride_vec}; @@ -94,10 +102,18 @@ static std::unordered_set }; auto candidate_dimensions = [](OperatorTaskSpace const &task) { - std::unordered_set options = { - MachineSpecificationDimension::INTER_NODE, - MachineSpecificationDimension::INTRA_NODE}; - return get_all_permutations_with_repetition(options, num_dims(task)); + std::vector> dimension_options = + transform(task.degrees, [](nonnegative_int dim_size) { + if (dim_size == 1_n) { + return std::vector{ + MachineSpecificationDimension::INTRA_NODE}; + } else { + return std::vector{ + MachineSpecificationDimension::INTER_NODE, + MachineSpecificationDimension::INTRA_NODE}; + } + }); + return cartesian_product(dimension_options); }; std::vector tensor_dims = task.degrees; diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 49d528e4ab..0743301e8f 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -16,9 +16,13 @@ #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/contains.h" +#include "utils/containers/contains_key.h" #include "utils/containers/flatmap.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_all_assignments.h" +#include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/set_minus.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" #include "utils/overload.h" @@ -80,17 +84,23 @@ MachineMappingResult ¶llel_split_transformation) { auto get_boundary_machine_view_assignments = - [&](std::unordered_set const &boundary_layers) + [&](std::unordered_set const &boundary_layers, + MachineMappingProblemTree const &t, + BinaryTreePathEntry const &prefix) -> std::unordered_set { + std::unordered_set unconstrained_boundary_layers = + set_minus(boundary_layers, + keys(restrict_to_child(constraints, prefix).machine_views)); + std::unordered_map> allowed = generate_map( - boundary_layers, + unconstrained_boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { + MachineMappingProblemTree subtree_at_path = + expect(mm_problem_tree_get_subtree_at_path(t, l), + "Failed to get subtree at path"); UnmappedOpCostEstimateKey leaf = - mm_problem_tree_get_subtree_at_path( - MachineMappingProblemTree{series_split}, l) - .value() - .get(); + subtree_at_path.get(); return context.allowed_machine_views(leaf, resources); }); return transform( @@ -138,24 +148,37 @@ MachineMappingResult for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : - get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + get_boundary_machine_view_assignments(get_src_layers(tensor_movement), + series_split.get_left_child(), + BinaryTreePathEntry::LEFT_CHILD)) { MachineMappingResult pre_result = eval_pre_boundary_mapping(assigned_pre_machine_views); + if (is_infeasible(pre_result)) { + continue; + } + for (ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views : get_boundary_machine_view_assignments( - get_dst_layers(tensor_movement))) { + get_dst_layers(tensor_movement), + series_split.get_right_child(), + BinaryTreePathEntry::RIGHT_CHILD)) { MachineMappingResult post_result = eval_post_boundary_mapping(assigned_post_machine_views); + if (is_infeasible(post_result)) { + continue; + } + TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement( tensor_movement, - /*pre_mapping=*/assigned_pre_machine_views, - /*post_mapping=*/assigned_post_machine_views); + /*pre_mapping=*/pre_result.raw_result.value().machine_mapping, + /*post_mapping=*/post_result.raw_result.value().machine_mapping); + float cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 82c8274808..07bde820e9 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -1,7 +1,16 @@ #include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.h" +#include "pcg/operator_task_space.dtg.h" +#include "pcg/operator_task_space.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" +#include "utils/containers/map_keys.h" #include "utils/containers/merge_maps.h" +#include "utils/containers/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" namespace FlexFlow { @@ -15,4 +24,39 @@ bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); } +parallel_layer_guid_t + get_layer_from_path(PCGBinarySPDecomposition const &sp_decomposition, + BinaryTreePath const &path) { + std::optional subtree_optional = + get_subtree_at_path( + sp_decomposition, generic_impl_for_pcg_sp_tree(), path); + + if (!subtree_optional.has_value()) { + throw std::runtime_error(fmt::format("Invalid tree path {}", path)); + } + + PCGBinarySPDecomposition subtree = subtree_optional.value(); + if (!subtree.is_leaf()) { + throw std::runtime_error( + fmt::format("Invalid tree path to a leaf: found {} instead", subtree)); + } + return subtree.require_leaf(); +} + +std::optional get_machine_mapping_from_machine_mapping_result( + PCGBinarySPDecomposition const &sp_decomposition, + MachineMappingResult const &mm_result) { + + return transform( + mm_result.raw_result, + [&](FeasibleMachineMappingResult const &feasible_mm_result) { + return MachineMapping{ + map_keys(feasible_mm_result.machine_mapping.raw_mapping, + [&](BinaryTreePath const &path) { + return get_layer_from_path(sp_decomposition, path); + }), + }; + }); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 367af3701e..1d000ff041 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,14 +1,50 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/all_of.h" #include "utils/overload.h" namespace FlexFlow { +bool is_valid_machine_mapping_problem_tree( + MachineMappingProblemTree const &problem_tree) { + return problem_tree.visit(overload{ + [&](MMProblemTreeSeriesSplit const &series_split) { + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; + + auto contains_paths = + [](MachineMappingProblemTree const &t, + std::unordered_set const &paths) { + return all_of(paths, [&](BinaryTreePath const &p) { + return mm_problem_tree_get_subtree_at_path(t, p).has_value(); + }); + }; + + return contains_paths(series_split.get_left_child(), + get_src_layers(tensor_movement)) && + contains_paths(series_split.get_right_child(), + get_dst_layers(tensor_movement)) && + is_valid_machine_mapping_problem_tree( + series_split.get_left_child()) && + is_valid_machine_mapping_problem_tree( + series_split.get_right_child()); + }, + [&](MMProblemTreeParallelSplit const ¶llel_split) { + return is_valid_machine_mapping_problem_tree( + parallel_split.get_left_child()) && + is_valid_machine_mapping_problem_tree( + parallel_split.get_right_child()); + }, + [&](UnmappedOpCostEstimateKey const &leaf) { return true; }, + }); +} + MachineMappingProblemTree get_machine_mapping_problem_tree( ParallelComputationGraph const &pcg, PCGBinarySPDecomposition const &sp_decomposition_tree) { @@ -23,31 +59,43 @@ MachineMappingProblemTree get_machine_mapping_problem_tree( [&](PCGBinarySeriesSplit const &series) { AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_set_movement_across_split(tr_pcg, series); - return MachineMappingProblemTree{ + MachineMappingProblemTree result = MachineMappingProblemTree{ MMProblemTreeSeriesSplit{ /*tensor_set_movement=*/tensor_movement, /*lhs=*/to_problem_tree(series.get_left_child()), /*rhs=*/to_problem_tree(series.get_right_child()), }, }; + assert(is_valid_machine_mapping_problem_tree(result)); + return result; }, [&](PCGBinaryParallelSplit const ¶llel) { - return MachineMappingProblemTree{ + MachineMappingProblemTree result = MachineMappingProblemTree{ MMProblemTreeParallelSplit{ to_problem_tree(parallel.get_left_child()), to_problem_tree(parallel.get_right_child()), }, }; + assert(is_valid_machine_mapping_problem_tree(result)); + return result; }, [&](parallel_layer_guid_t const &leaf) { - return MachineMappingProblemTree{ + MachineMappingProblemTree result = MachineMappingProblemTree{ get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), }; + assert(is_valid_machine_mapping_problem_tree(result)); + return result; }, }); }; - return to_problem_tree(sp_decomposition_tree); + MachineMappingProblemTree mm_tree = to_problem_tree(sp_decomposition_tree); + + if (!is_valid_machine_mapping_problem_tree(mm_tree)) { + throw std::runtime_error("Invalid machine mapping problem tree generated"); + } + + return mm_tree; } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 1e39a7be19..7834938e41 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -1,4 +1,6 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" @@ -88,4 +90,54 @@ std::optional tree, generic_binary_sp_impl_for_mm_problem_tree(), path); } +std::string as_dot(MachineMappingProblemTree const &tree) { + std::function + get_series_label = + [](MMProblemTreeSeriesSplit const &series) -> std::string { + auto path_as_dot = [](BinaryTreePath const &path) -> std::string { + return "(" + + join_strings(path.entries, + ", ", + [](BinaryTreePathEntry const &entry) -> std::string { + if (entry == BinaryTreePathEntry::LEFT_CHILD) { + return "l"; + } else { + assert(entry == BinaryTreePathEntry::RIGHT_CHILD); + return "r"; + } + }) + + ")"; + }; + + auto path_set_as_dot = + [&](std::unordered_set const &path_set) -> std::string { + return "(" + join_strings(path_set, ", ", path_as_dot) + ")"; + }; + + return fmt::format( + "srcs={} dsts={}", + path_set_as_dot(get_src_layers(series.tensor_set_movement)), + path_set_as_dot(get_dst_layers(series.tensor_set_movement))); + }; + + std::function + get_parallel_label = + [](MMProblemTreeParallelSplit const ¶llel) -> std::string { + return "P"; + }; + + std::function get_leaf_label = + [](UnmappedOpCostEstimateKey const &leaf) -> std::string { return ""; }; + + return as_dot(tree, + generic_binary_sp_impl_for_mm_problem_tree(), + get_series_label, + get_parallel_label, + get_leaf_label); +} + +void debug_print_dot(MachineMappingProblemTree const &tree) { + std::cout << as_dot(tree) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc index 990b287f8b..b6d701cb98 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc @@ -1,4 +1,5 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "pcg/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" @@ -18,6 +19,8 @@ UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( transform(get_incoming_weights(pcg, layer), get_tensor_shape), /*output_shapes=*/ transform(get_layer_outputs(pcg, layer), get_tensor_shape), + /*op_task_space=*/ + get_operator_task_space(pcg, layer), }; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 3409f7f871..031b7f7fc5 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -135,4 +135,12 @@ MachineMappingResult }; } +float get_runtime_cost(MachineMappingResult const &mm_result) { + if (mm_result.raw_result == std::nullopt) { + return std::numeric_limits::infinity(); + } else { + return mm_result.raw_result.value().runtime; + } +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_result.cc b/lib/compiler/src/compiler/search_result.cc similarity index 58% rename from lib/compiler/src/compiler/graph_optimize_result.cc rename to lib/compiler/src/compiler/search_result.cc index f48c119603..33243a226d 100644 --- a/lib/compiler/src/compiler/graph_optimize_result.cc +++ b/lib/compiler/src/compiler/search_result.cc @@ -1,14 +1,14 @@ -#include "compiler/graph_optimize_result.h" +#include "compiler/search_result.h" namespace FlexFlow { -std::string format_as(GraphOptimizeResult const &r) { +std::string format_as(SearchResult const &r) { return fmt::format("", as_dot(r.pcg), r.machine_mapping); } -std::ostream &operator<<(std::ostream &s, GraphOptimizeResult const &r) { +std::ostream &operator<<(std::ostream &s, SearchResult const &r) { return (s << fmt::to_string(r)); } diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index 5eb993c6ef..7b4670c608 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -1,7 +1,10 @@ #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" #include "utils/overload.h" namespace FlexFlow { @@ -82,8 +85,63 @@ BinarySPDecompositionTree } std::optional - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { - NOT_IMPLEMENTED(); + get_pcg_balanced_binary_sp_decomposition( + ParallelComputationGraph const &pcg) { + SeriesParallelDecomposition sp_decomp = + expect(get_pcg_series_parallel_decomposition(pcg), + "Failed to get SP decomposition of PCG"); + BinarySPDecompositionTree binary_sp_tree = + left_associative_binary_sp_tree_from_nary(sp_decomp); + return pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + binary_sp_tree); +} + +PCGBinarySeriesSplit pcg_binary_series_split_from_binary_series_split( + BinarySeriesSplit const &split) { + return PCGBinarySeriesSplit{ + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_left_child()), + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_right_child()), + }; +} + +PCGBinaryParallelSplit pcg_binary_parallel_split_from_binary_parallel_split( + BinaryParallelSplit const &split) { + return PCGBinaryParallelSplit{ + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_left_child()), + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_right_child()), + }; +} + +PCGBinarySPDecomposition + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + BinarySPDecompositionTree const &sp_tree) { + + return sp_tree.visit(overload{ + [](BinarySeriesSplit const &series) -> PCGBinarySPDecomposition { + return PCGBinarySPDecomposition{ + pcg_binary_series_split_from_binary_series_split(series), + }; + }, + [](BinaryParallelSplit const ¶llel) -> PCGBinarySPDecomposition { + return PCGBinarySPDecomposition{ + PCGBinaryParallelSplit{ + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + parallel.get_left_child()), + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + parallel.get_right_child()), + }, + }; + }, + [](Node const &node) -> PCGBinarySPDecomposition { + return PCGBinarySPDecomposition{ + parallel_layer_guid_t{node}, + }; + }, + }); } std::unordered_multiset diff --git a/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc b/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc new file mode 100644 index 0000000000..22e319321b --- /dev/null +++ b/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc @@ -0,0 +1,61 @@ +#include "compiler/unity_algorithm/graph_optimize_state.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" + +namespace FlexFlow { + +GraphOptimizeState::GraphOptimizeState(ParallelComputationGraph const &pcg, + float runtime_with_optimal_mm) + : pcg(pcg), runtime_with_optimal_mm(runtime_with_optimal_mm) {} + +bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + return pcgs_are_isomorphic(pcg, other.pcg); +} + +bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { + return !(*this == other); +} + +bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { + return runtime_with_optimal_mm < other.runtime_with_optimal_mm; +} + +std::string format_as(GraphOptimizeState const &st) { + return fmt::format("", + as_dot(st.pcg), + st.runtime_with_optimal_mm); +} + +std::ostream &operator<<(std::ostream &s, GraphOptimizeState const &st) { + return (s << fmt::to_string(st)); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::GraphOptimizeState>::operator()( + ::FlexFlow::GraphOptimizeState const &state) const { + // TODO(@wmdi): Eventually it might be good to use a proper graph hash like + // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash + size_t seed = 0; + std::vector<::FlexFlow::parallel_layer_guid_t> layers = + topological_ordering(state.pcg); + ::FlexFlow::hash_combine(seed, layers.size()); + for (::FlexFlow::parallel_layer_guid_t const &layer : layers) { + ::FlexFlow::hash_combine(seed, get_parallel_layer_attrs(state.pcg, layer)); + std::vector<::FlexFlow::parallel_tensor_guid_t> inputs = + get_incoming_tensors(state.pcg, layer); + ::FlexFlow::hash_combine(seed, inputs.size()); + for (::FlexFlow::parallel_tensor_guid_t input : inputs) { + for (size_t i = 0; i < layers.size(); ++i) { + if (get_source_layer(input) == layers.at(i)) { + ::FlexFlow::hash_combine(seed, i); + break; + } + } + } + } + return seed; +} + +} // namespace std diff --git a/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc new file mode 100644 index 0000000000..caaefbfdbf --- /dev/null +++ b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc @@ -0,0 +1,138 @@ +#include "compiler/unity_algorithm/unity_algorithm.h" +#include "compiler/machine_mapping/allowed_machine_views.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" +#include "compiler/unity_algorithm/graph_optimize_state.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/operator_task_space.h" +#include "substitutions/apply_substitution/apply_substitution.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution.h" +#include "substitutions/unity_substitution_set.h" +#include "utils/containers/generate_map.h" +#include "utils/deduplicated_priority_queue.h" +#include "utils/graph/node/algorithms.h" +#include "utils/optional.h" + +namespace FlexFlow { + +/* + * Applies a substitution to all possible positions in PCG + */ +std::vector + all_pcgs_obtained_by_applying_a_substitution( + ParallelComputationGraph const &pcg, + std::vector const &substitutions) { + std::vector results; + SubParallelComputationGraph subpcg = sub_pcg_from_full_pcg(pcg); + for (Substitution const &substitution : substitutions) { + for (PCGPatternMatch const &pattern_match : + find_pattern_matches(substitution.pcg_pattern, subpcg)) { + SubParallelComputationGraph subpcg_from_substitution = + apply_substitution(subpcg, substitution, pattern_match); + results.push_back( + pcg_from_sub_pcg_by_dropping_inputs(subpcg_from_substitution)); + } + } + return results; +} + +SearchResult graph_optimize(ParallelComputationGraph &pcg, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + UnitySearchConfig const &search_config) { + + std::vector substitutions = get_substitution_set(resources); + + MachineMappingCache cached_subgraph_costs = empty_machine_mapping_cache(); + DeduplicatedPriorityQueue candidates; + + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/cost_estimator, + /*allowed_machine_views=*/ + [&](UnmappedOpCostEstimateKey const &key, + MachineSpecification const &resources) + -> std::unordered_set { + return get_allowed_machine_views( + resources, key.op_task_space, DeviceType::GPU); + }, + }; + + auto optimize_pcg = [&](ParallelComputationGraph const &pcg) + -> std::pair> { + PCGBinarySPDecomposition sp_decomp = + expect(get_pcg_balanced_binary_sp_decomposition(pcg), + "Failed to get SP decomposition of PCG"); + + MachineMappingProblemTree problem_tree = + get_machine_mapping_problem_tree(pcg, sp_decomp); + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); + + MachineMappingResult mm_result = get_optimal_machine_mapping( + cached_subgraph_costs, context, problem_tree, resources, constraints); + + return { + GraphOptimizeState{ + /*pcg=*/pcg, + /*runtime_with_optimal_mm=*/get_runtime_cost(mm_result), + }, + get_machine_mapping_from_machine_mapping_result(sp_decomp, mm_result), + }; + }; + + GraphOptimizeState best_state = optimize_pcg(pcg).first; + candidates.push(best_state); + + for (int iteration = 0; + !candidates.empty() && iteration < search_config.budget; + ++iteration) { + GraphOptimizeState current_state = candidates.top(); + candidates.pop(); + + if (current_state < best_state) { + best_state = current_state; + } else if (current_state.runtime_with_optimal_mm > + best_state.runtime_with_optimal_mm * search_config.alpha) { + continue; + } + + for (ParallelComputationGraph const &new_pcg : + all_pcgs_obtained_by_applying_a_substitution(current_state.pcg, + substitutions)) { + std::optional new_pcg_optimize_result = + optimize_pcg(new_pcg).first; + if (new_pcg_optimize_result == std::nullopt) { + continue; + } + GraphOptimizeState new_state = new_pcg_optimize_result.value(); + if (new_state.runtime_with_optimal_mm <= search_config.threshold && + get_nodes(new_pcg.raw_graph).size() <= search_config.max_num_ops) { + candidates.push(new_state); + } + } + } + + std::optional best_mapping = + optimize_pcg(best_state.pcg).second; + + if (best_mapping == std::nullopt) { + throw std::runtime_error("Failed to find any solutions"); + } + + return SearchResult{ + /*pcg=*/best_state.pcg, + /*machine_mapping=*/best_mapping.value(), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc deleted file mode 100644 index 86a211c535..0000000000 --- a/lib/compiler/src/unity_algorithm.cc +++ /dev/null @@ -1,93 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "compiler/graph_optimize_state.h" -#include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "pcg/machine_specification.dtg.h" -#include "substitutions/substitution.h" -#include "utils/deduplicated_priority_queue.h" -#include "utils/graph/node/algorithms.h" -namespace FlexFlow { - -/* - * Gets all substitutions applicable to a PCG - */ -std::vector - get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); -} - -/* - * Applies a substitution to all possible positions in PCG - */ -std::vector - apply_substitution(ParallelComputationGraph const &pcg, - Substitution const &) { - NOT_IMPLEMENTED(); -} - -GraphOptimizeResult graph_optimize( - ParallelComputationGraph &pcg, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimizerConfig const &opt_config) { - NOT_IMPLEMENTED(); - - // std::vector substitutions = - // get_all_applicable_substitutions(pcg); - // - // MachineMappingCache cached_subgraph_costs; - // DeduplicatedPriorityQueue candidates; - // - // MachineMappingResult original_pcg_cost = - // get_optimal_machine_mapping(pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // - // GraphOptimizeState initial_state = { - // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), - // original_pcg_cost.runtime}; - // - // GraphOptimizeState best_state = initial_state; - // candidates.push(initial_state); - // - // for (int iteration = 0; !candidates.empty() && iteration < - // opt_config.budget; - // ++iteration) { - // GraphOptimizeState current_state = candidates.top(); - // candidates.pop(); - // - // if (current_state.runtime < best_state.runtime) { - // best_state = current_state; - // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) - // { - // continue; - // } - // - // for (Substitution const &substitution : substitutions) { - // for (ParallelComputationGraph const &new_pcg : apply_substitution( - // current_state.graph_optimize_result.pcg, substitution)) { - // MachineMappingResult new_pcg_cost = - // get_optimal_machine_mapping(new_pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // GraphOptimizeState new_state{ - // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), - // new_pcg_cost.runtime}; - // if (new_pcg_cost.runtime <= opt_config.threshold && - // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - // candidates.push(new_state); - // } - // } - // } - // } - - // return best_state.graph_optimize_result; -} - -} // namespace FlexFlow diff --git a/lib/compiler/test/src/allowed_machine_views.cc b/lib/compiler/test/src/allowed_machine_views.cc deleted file mode 100644 index 817cc80700..0000000000 --- a/lib/compiler/test/src/allowed_machine_views.cc +++ /dev/null @@ -1,110 +0,0 @@ -#include "compiler/allowed_machine_views.h" -#include "doctest/doctest.h" -#include "utils/containers/extend.h" -#include "utils/containers/range.h" -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/containers/zip.h" -#include "utils/fmt/unordered_set.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("get_allowed_machine_views") { - - SUBCASE("1 degree of parallelism") { - MachineSpecification ms = MachineSpecification{ - /*num_nodes=*/1_n, - /*num_cpus_per_node=*/5_n, - /*num_gpus_per_node=*/5_n, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0, - }; - - OperatorTaskSpace task = OperatorTaskSpace{{3_n}}; - - std::unordered_set correct = { - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_n}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_n}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/2_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_n}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{2_n}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - }; - - std::unordered_set result = - get_allowed_machine_views(ms, task, DeviceType::GPU); - - CHECK(correct == result); - } - - SUBCASE("2 degrees of parallelism") { - - MachineSpecification ms = MachineSpecification{ - /*num_nodes=*/3_n, - /*num_cpus_per_node=*/3_n, - /*num_gpus_per_node=*/3_n, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0, - }; - OperatorTaskSpace task = OperatorTaskSpace{{2_n, 3_n}}; - - auto make_2d_view = [&](nonnegative_int start_node_idx, - nonnegative_int start_device_idx, - nonnegative_int stride1, - nonnegative_int stride2, - MachineSpecificationDimension m1, - MachineSpecificationDimension m2) { - return MachineView{ - MachineSpaceCoordinate{ - start_node_idx, start_device_idx, DeviceType::GPU}, - {MachineViewDimension{stride_t{stride1}, m1}, - MachineViewDimension{stride_t{stride2}, m2}}, - }; - }; - - auto intra = MachineSpecificationDimension::INTRA_NODE; - auto inter = MachineSpecificationDimension::INTER_NODE; - std::unordered_set correct = { - make_2d_view( - 0_n, 0_n, /*stride1=*/1_n, /*stride2=*/1_n, inter, intra), - make_2d_view( - 1_n, 0_n, /*stride1=*/1_n, /*stride2=*/1_n, inter, intra), - make_2d_view( - 0_n, 0_n, /*stride1=*/2_n, /*stride2=*/1_n, inter, intra), - - make_2d_view( - 0_n, 0_n, /*stride1=*/1_n, /*stride2=*/1_n, intra, inter), - make_2d_view( - 0_n, 1_n, /*stride1=*/1_n, /*stride2=*/1_n, intra, inter), - make_2d_view( - 0_n, 0_n, /*stride1=*/2_n, /*stride2=*/1_n, intra, inter), - }; - - std::unordered_set result = - get_allowed_machine_views(ms, task, DeviceType::GPU); - - CHECK(correct == result); - } - } -} diff --git a/lib/compiler/test/src/compiler/machine_mapping/allowed_machine_views.cc b/lib/compiler/test/src/compiler/machine_mapping/allowed_machine_views.cc new file mode 100644 index 0000000000..f176621a18 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/allowed_machine_views.cc @@ -0,0 +1,156 @@ +#include "compiler/machine_mapping/allowed_machine_views.h" +#include "doctest/doctest.h" +#include "utils/containers/extend.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip.h" +#include "utils/fmt/unordered_set.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_allowed_machine_views") { + + auto make_2d_view = [&](nonnegative_int start_node_idx, + nonnegative_int start_device_idx, + nonnegative_int stride_1, + nonnegative_int stride_2, + MachineSpecificationDimension m1, + MachineSpecificationDimension m2) { + return MachineView{ + MachineSpaceCoordinate{ + start_node_idx, start_device_idx, DeviceType::GPU}, + {MachineViewDimension{stride_t{stride_1}, m1}, + MachineViewDimension{stride_t{stride_2}, m2}}, + }; + }; + auto intra = MachineSpecificationDimension::INTRA_NODE; + auto inter = MachineSpecificationDimension::INTER_NODE; + + SUBCASE("1 degree of parallelism") { + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/1_n, + /*num_cpus_per_node=*/5_n, + /*num_gpus_per_node=*/5_n, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + + OperatorTaskSpace task = OperatorTaskSpace{{3_n}}; + + std::unordered_set correct = { + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU}, + {MachineViewDimension{stride_t{1_n}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, + {MachineViewDimension{stride_t{1_n}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, /*device_idx=*/2_n, DeviceType::GPU}, + {MachineViewDimension{stride_t{1_n}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU}, + {MachineViewDimension{stride_t{2_n}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + + SUBCASE("2 degrees of parallelism") { + + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/3_n, + /*num_cpus_per_node=*/3_n, + /*num_gpus_per_node=*/3_n, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + OperatorTaskSpace task = OperatorTaskSpace{{2_n, 3_n}}; + + std::unordered_set correct = { + make_2d_view( + 0_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, inter, intra), + make_2d_view( + 1_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, inter, intra), + make_2d_view( + 0_n, 0_n, /*stride_1=*/2_n, /*stride_2=*/1_n, inter, intra), + + make_2d_view( + 0_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, inter), + make_2d_view( + 0_n, 1_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, inter), + make_2d_view( + 0_n, 0_n, /*stride_1=*/2_n, /*stride_2=*/1_n, intra, inter), + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + + SUBCASE("2D operator task space, dimensions (1,1)") { + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/nonnegative_int{2}, + /*num_cpus_per_node=*/nonnegative_int{1}, + /*num_gpus_per_node=*/nonnegative_int{1}, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + OperatorTaskSpace task = OperatorTaskSpace{{1_n, 1_n}}; + + std::unordered_set result = + get_allowed_machine_views(full_machine_spec, task, DeviceType::GPU); + + std::unordered_set correct = { + make_2d_view( + 0_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, intra), + make_2d_view( + 1_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, intra)}; + CHECK(correct == result); + } + + SUBCASE("2D operator task space, dimensions (2,1)") { + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/nonnegative_int{2}, + /*num_cpus_per_node=*/nonnegative_int{2}, + /*num_gpus_per_node=*/nonnegative_int{2}, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + OperatorTaskSpace task = OperatorTaskSpace{{1_n, 2_n}}; + + std::unordered_set result = + get_allowed_machine_views(full_machine_spec, task, DeviceType::GPU); + + std::unordered_set correct = { + make_2d_view( + 0_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, intra), + make_2d_view( + 0_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, inter), + make_2d_view( + 1_n, 0_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, intra), + make_2d_view( + 0_n, 1_n, /*stride_1=*/1_n, /*stride_2=*/1_n, intra, inter)}; + CHECK(correct == result); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index e506dea1d7..a45227011c 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -109,11 +109,14 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; + OperatorTaskSpace fake_op_task_space = OperatorTaskSpace{{}}; + UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ /*op_attrs=*/PCGOperatorAttrs{InputAttrs{tensor_shape}}, /*input_shapes=*/{}, /*weight_shapes=*/{}, /*output_shapes=*/{}, + /*op_task_space=*/fake_op_task_space, }; UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ @@ -126,6 +129,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*input_shapes=*/{}, /*weight_shapes=*/{}, /*output_shapes=*/{}, + /*op_task_space=*/fake_op_task_space, }; ParallelTensorShape par_tensor_shape = lift_to_parallel(tensor_shape); diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 048f1ddcac..9059950742 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,8 +1,15 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h" #include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/containers/extend.h" #include "utils/containers/get_only.h" +#include "utils/containers/vector_of.h" #include using namespace ::FlexFlow; @@ -90,6 +97,14 @@ TEST_SUITE(FF_TEST_SUITE) { PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{input_shape}}; + auto make_operator_task_space = [&](ParallelTensorShape const &shape) { + std::vector degrees; + extend(degrees, vector_of(ff_ordered_shard_degrees(shape))); + degrees.push_back(get_sum_degree(shape)); + degrees.push_back(get_discard_copy_degree(shape)); + return OperatorTaskSpace{degrees}; + }; + auto make_input_key = [&](ParallelTensorShape const ¶llel_tensor_shape) { return UnmappedOpCostEstimateKey{ @@ -97,6 +112,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*input_shapes=*/{}, /*weight_shapes=*/{}, /*output_shapes=*/{parallel_tensor_shape}, + /*op_task_space=*/make_operator_task_space(parallel_tensor_shape), }; }; @@ -143,11 +159,15 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t relu_layer = relu_added.parallel_layer; parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + OperatorTaskSpace relu_task_space = + get_operator_task_space(pcg, relu_layer); + UnmappedOpCostEstimateKey relu_key = UnmappedOpCostEstimateKey{ /*op_attrs=*/relu_attrs, /*input_shapes=*/{par_input_shape}, /*weight_shapes=*/{}, /*output_shapes=*/{relu_output_shape}, + /*op_task_space=*/relu_task_space, }; PCGBinarySPDecomposition sp_decomposition = pcg_make_series( @@ -228,11 +248,14 @@ TEST_SUITE(FF_TEST_SUITE) { {input1_tensor, input2_tensor}, {}); parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; + OperatorTaskSpace ew_op_task_space = + get_operator_task_space(pcg, ew_op_layer); UnmappedOpCostEstimateKey ew_op_key = UnmappedOpCostEstimateKey{ /*op_attrs=*/ew_op_attrs, /*input_shapes=*/{par_input_shape, par_input_shape}, /*weight_shapes=*/{}, /*output_shapes=*/{ew_op_output_shape}, + /*op_task_space=*/ew_op_task_space, }; PCGBinarySPDecomposition sp_decomposition = @@ -280,4 +303,43 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } + + TEST_CASE("from pcg") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{ + FFOrdered{nonnegative_int{32}, + nonnegative_int{64}}, + }, + DataType::FLOAT, + }; + tensor_guid_t t = b.create_input(input_tensor_shape, CreateGrad::YES); + t = b.dense(t, + /*outDim=*/nonnegative_int{16}, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/nonnegative_int{12}, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + t = b.relu(t); + t = b.dense(t, + /*outDim=*/nonnegative_int{8}, + /*activation=*/Activation::RELU); + return b.computation_graph; + }(); + + ParallelComputationGraph pcg = pcg_from_computation_graph(cg); + + PCGBinarySPDecomposition sp_decomp = + expect(get_pcg_balanced_binary_sp_decomposition(pcg), + "Failed to get SP decomposition of PCG"); + + MachineMappingProblemTree problem_tree = + get_machine_mapping_problem_tree(pcg, sp_decomp); + } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc index 8ae1ebe753..f049f4b288 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -99,6 +99,7 @@ TEST_SUITE(FF_TEST_SUITE) { } }; + OperatorTaskSpace fake_op_task_space = OperatorTaskSpace{{}}; TensorShape tensor_shape = TensorShape{ TensorDims{ FFOrdered{ @@ -116,6 +117,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*input_shapes=*/{}, /*weight_shapes=*/{}, /*output_shapes=*/{}, + /*op_task_space=*/fake_op_task_space, }; UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ @@ -128,6 +130,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*input_shapes=*/{}, /*weight_shapes=*/{}, /*output_shapes=*/{}, + /*op_task_space=*/fake_op_task_space, }; AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ diff --git a/lib/compiler/test/src/graph_optimize_state.cc b/lib/compiler/test/src/compiler/unity_algorithm/graph_optimize_state.cc similarity index 68% rename from lib/compiler/test/src/graph_optimize_state.cc rename to lib/compiler/test/src/compiler/unity_algorithm/graph_optimize_state.cc index 5c00ce1558..3b146be93f 100644 --- a/lib/compiler/test/src/graph_optimize_state.cc +++ b/lib/compiler/test/src/compiler/unity_algorithm/graph_optimize_state.cc @@ -1,4 +1,4 @@ -#include "compiler/graph_optimize_state.h" +#include "compiler/unity_algorithm/graph_optimize_state.h" #include "doctest/doctest.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" @@ -15,24 +15,6 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT, }; - // ParallelTensorShape input_shape = - // ParallelTensorShape{ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{32_n, 2_n}, - // ShardParallelDim{16_n, 1_n}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{1_n}, - // DiscardCopyDegree{1_n}, - // }, - // }, - // DataType::FLOAT}; - - // `machine_mapping` is determined by the PCG and the device mapping - // algorithm, and `runtime` is determined by the PCG and the device mapping, - // so their values here do not matter. - std::unordered_map empty_machine_views; - MachineMapping empty_machine_mapping(empty_machine_views); InitializerAttrs zero_init = InitializerAttrs{ZeroInitializerAttrs{}}; @@ -70,13 +52,12 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph pcg2 = create_pcg(); GraphOptimizeState state1 = GraphOptimizeState{ - GraphOptimizeResult{pcg1, empty_machine_mapping}, - 0, + pcg1, + .0, }; - GraphOptimizeState state2 = GraphOptimizeState{ - GraphOptimizeResult{pcg2, empty_machine_mapping}, - 0, + pcg2, + .0, }; CHECK(state1 == state2); @@ -100,16 +81,30 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph pcg_ = builder_.pcg; GraphOptimizeState state1 = GraphOptimizeState{ - GraphOptimizeResult{pcg1, empty_machine_mapping}, - 0, + pcg1, + .0, }; GraphOptimizeState state_ = GraphOptimizeState{ - GraphOptimizeResult{pcg_, empty_machine_mapping}, - 0, + pcg_, + .0, }; CHECK_FALSE(state1 == state_); } } + + TEST_CASE("GraphOptimizeState::operator<") { + ParallelComputationGraph pcg1 = empty_parallel_computation_graph(); + ParallelComputationGraph pcg2 = empty_parallel_computation_graph(); + GraphOptimizeState state1 = GraphOptimizeState{ + pcg1, + 1.0, + }; + GraphOptimizeState state2 = GraphOptimizeState{ + pcg2, + 2.0, + }; + CHECK(state1 < state2); + } } diff --git a/lib/compiler/test/src/compiler/unity_algorithm/unity_algorithm.cc b/lib/compiler/test/src/compiler/unity_algorithm/unity_algorithm.cc new file mode 100644 index 0000000000..73b4a7c80b --- /dev/null +++ b/lib/compiler/test/src/compiler/unity_algorithm/unity_algorithm.cc @@ -0,0 +1,88 @@ +#include "compiler/unity_algorithm/unity_algorithm.h" +#include "../cost_estimator_for_test.h" +#include "doctest/doctest.h" +#include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/replica_type.dtg.h" +#include "op-attrs/shard_parallel_dim.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/integer_conversions.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("graph_optimize") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{ + FFOrdered{nonnegative_int{32}, + nonnegative_int{64}}, + }, + DataType::FLOAT, + }; + tensor_guid_t t = b.create_input(input_tensor_shape, CreateGrad::YES); + t = b.dense(t, + /*outDim=*/nonnegative_int{16}, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/nonnegative_int{12}, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + t = b.relu(t); + t = b.dense(t, + /*outDim=*/nonnegative_int{8}, + /*activation=*/Activation::RELU); + return b.computation_graph; + }(); + + ParallelComputationGraph pcg = pcg_from_computation_graph(cg); + + CostEstimator cost_estimator = make_fake_cost_estimator( + [](OpCostEstimateKey const &k) { + return OpCostMetrics{ + /*forward_runtime=*/1.0, + /*backward_runtime=*/2.0, + /*memory=*/nonnegative_int{1}, + }; + }, + [](TensorSetMovement const &) { return 1.0; }); + + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/nonnegative_int{2}, + /*num_cpus_per_node=*/nonnegative_int{1}, + /*num_gpus_per_node=*/nonnegative_int{1}, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + SUBCASE("do not apply substitution") { + UnitySearchConfig search_config = UnitySearchConfig{ + /*alpha=*/1.0, + /*budget=*/0, + /*threshold=*/1000.0, + /*max_num_ops=*/100, + }; + SearchResult result = + graph_optimize(pcg, cost_estimator, full_machine_spec, search_config); + CHECK(pcgs_are_isomorphic(pcg, result.pcg)); + } + + SUBCASE("apply substitution") { + UnitySearchConfig search_config = UnitySearchConfig{ + /*alpha=*/1.0, + /*budget=*/1, + /*threshold=*/1000.0, + /*max_num_ops=*/100, + }; + SearchResult result = + graph_optimize(pcg, cost_estimator, full_machine_spec, search_config); + } + } +} diff --git a/lib/compiler/test/src/unity_algorithm.cc b/lib/compiler/test/src/unity_algorithm.cc deleted file mode 100644 index 8ff0978ea5..0000000000 --- a/lib/compiler/test/src/unity_algorithm.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest/doctest.h" - -TEST_SUITE(FF_TEST_SUITE) { - // Rapidcheck does not work for now - // TEST_CASE("graph_optimize") { - // RC_SUBCASE([](ComputationGraph const &g, - // float alpha, - // int budget, - // float threshold, - // int max_num_ops) { - // Strategy s = graph_optimize( - // g, - // TestCostEstimator{}, - // MachineSpecification{1, 1, 4, 0.1, 0.2}, - // [](Operator const &, MachineSpecification const &) { - // return std::unordered_set{make_1d_machine_view(0, 1, - // 1)}; - // }, - // OptimizerConfig{alpha, budget, threshold, max_num_ops}); - // RC_ASSERT(get_nodes(s.pcg).size() > 0); - // RC_ASSERT(s.machine_mapping.runtime > 0); - // RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); - // }); - // } -} diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 3542e73dea..f820c56d61 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H +#include "pcg/computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index 194ae49255..f39b771364 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -16,6 +16,16 @@ bool operator_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val.value() == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + if (expr_val.value().has() && + constraint.attribute_value.has()) { + return expr_val.value().get() % + constraint.attribute_value.get() == + 0; + } + throw mk_runtime_error( + "DIVISIBLE_BY constraint requires nonnegative_int values"); + } default: throw mk_runtime_error( fmt::format("Unknown constraint type {}", diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc index 974bfcabc0..cc0af12c91 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -12,6 +12,16 @@ bool parallel_tensor_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + if (expr_val.has() && + constraint.attribute_value.has()) { + return expr_val.get() % + constraint.attribute_value.get() == + 0; + } + throw mk_runtime_error( + "DIVISIBLE_BY constraint requires nonnegative_int values"); + } default: throw mk_runtime_error( fmt::format("Unknown constraint type {}", diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index a7ebc0bff7..9d8e4bc259 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -11,6 +11,7 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/overload.h" namespace FlexFlow { @@ -67,6 +68,27 @@ static std::optional return match; } +MatchAdditionalCriterion additional_criterion_for_subpattern( + MatchAdditionalCriterion const &full_additional_criterion, + bidict const + &full_pattern_values_to_subpattern_inputs) { + return MatchAdditionalCriterion{ + full_additional_criterion.node_criterion, + [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { + return patternValue.visit( + overload{[&](PatternNodeOutput const &) -> bool { + return full_additional_criterion.value_criterion( + patternValue, pcgValue); + }, + [&](PatternInput const &i) -> bool { + PatternValue full_pattern_value = + full_pattern_values_to_subpattern_inputs.at_r(i); + return full_additional_criterion.value_criterion( + full_pattern_value, pcgValue); + }}); + }}; +} + std::vector find_pattern_matches(UnlabelledGraphPattern const &pattern, OpenDataflowGraphView const &graph, @@ -87,10 +109,18 @@ std::vector PatternSplitResult subpatterns = apply_split(pattern, split); std::vector prefix_matches = find_pattern_matches( - subpatterns.subpattern_1, graph, additional_criterion); + subpatterns.subpattern_1, + graph, + additional_criterion_for_subpattern( + additional_criterion, + subpatterns.full_pattern_values_to_subpattern_1_inputs)); std::vector postfix_matches = find_pattern_matches( - subpatterns.subpattern_2, graph, additional_criterion); + subpatterns.subpattern_2, + graph, + additional_criterion_for_subpattern( + additional_criterion, + subpatterns.full_pattern_values_to_subpattern_2_inputs)); for (UnlabelledDataflowGraphPatternMatch const &prefix_match : prefix_matches) { diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 304bb8cf46..c7b03e24f2 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -7,10 +7,13 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" +#include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" +#include "utils/containers/values.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" #include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" @@ -18,6 +21,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/overload.h" +#include #include namespace FlexFlow { @@ -46,8 +50,13 @@ struct SubgraphConcreteFromPattern { } OpenDataflowValue operator()(PatternInput const &i) const { - return OpenDataflowValue{full_graph_values_to_subgraph_inputs.at_l( - match.input_assignment.at(i))}; + OpenDataflowValue mapped_input = match.input_assignment.at(i); + if (full_graph_values_to_subgraph_inputs.contains_l(mapped_input)) { + return OpenDataflowValue{ + full_graph_values_to_subgraph_inputs.at_l(mapped_input)}; + } else { + return mapped_input; + } } OpenDataflowEdge operator()(InputPatternEdge const &e) const { @@ -148,11 +157,27 @@ bool unlabelled_pattern_does_match( UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { + std::unordered_set matched_by_pattern_inputs = + unordered_set_of(values(match.input_assignment)); + + ASSERT(left_entries(match.node_assignment) == get_nodes(pattern)); + ASSERT( + is_subseteq_of(right_entries(match.node_assignment), get_nodes(graph))); + ASSERT(keys(match.input_assignment) == get_graph_inputs(pattern)); + ASSERT(is_subseteq_of(matched_by_pattern_inputs, + get_open_dataflow_values(graph))); + OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match); OpenDataflowGraphView matched_subgraph = subgraph_result.graph; - assert(left_entries(match.node_assignment) == get_nodes(pattern)); - assert(right_entries(match.node_assignment) == get_nodes(matched_subgraph)); + std::unordered_set full_values_split_by_subgraph = + left_entries(subgraph_result.full_graph_values_to_subgraph_inputs); + + ASSERT(right_entries(match.node_assignment) == get_nodes(matched_subgraph)); + ASSERT(is_subseteq_of(full_values_split_by_subgraph, + get_open_dataflow_values(graph)), + full_values_split_by_subgraph, + get_open_dataflow_values(graph)); MatchAdditionalCriterion through_subgraph_operation = MatchAdditionalCriterion{ diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 8ba1fee873..ccc83e12d6 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -152,5 +152,118 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set correct = {match1, match2}; CHECK(result == correct); + + SUBCASE("pcg is a chain") { + ParallelComputationGraphBuilder builder; + + nonnegative_int batch_size = 16_n; + nonnegative_int batch_degree = 2_n; + nonnegative_int num_channels = 24_n; + + TensorShape a_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + num_channels, + }, + }, + DataType::FLOAT, + }; + + std::string a_name = "a"; + + parallel_tensor_guid_t a_tensor = builder.create_input_tensor(a_shape); + a_tensor = + builder.parallel_partition(a_tensor, ff_dim_t{0_n}, batch_degree); + + nonnegative_int outDim = 16_n; + std::string x_matmul_name = "x_matmul"; + std::string y_matmul_name = "y_matmul"; + parallel_tensor_guid_t t0 = + builder.dense(a_tensor, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + x_matmul_name); + parallel_tensor_guid_t t1 = + builder.dense(t0, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + y_matmul_name); + parallel_tensor_guid_t t2 = + builder.dense(t1, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + parallel_tensor_guid_t t3 = + builder.dense(t2, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + ParallelComputationGraph pcg = builder.pcg; + + LabelledOpenDataflowGraph + g = LabelledOpenDataflowGraph:: + create>(); + + TensorAttributePattern pattern_tensor_a = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_b = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_c = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_x = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_y = + tensor_attribute_pattern_match_all(); + + OperatorAttributePattern op_pattern_1 = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + }}; + + OperatorAttributePattern op_pattern_2 = op_pattern_1; + + DataflowGraphInput pt_a = g.add_input(pattern_tensor_a); + DataflowGraphInput pt_b = g.add_input(pattern_tensor_b); + DataflowGraphInput pt_c = g.add_input(pattern_tensor_c); + + NodeAddedResult op_pattern_1_added = + g.add_node(op_pattern_1, + {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_b}}, + {pattern_tensor_x}); + PatternNode op_pattern_1_node = PatternNode{op_pattern_1_added.node}; + OpenDataflowValue pt_x = + OpenDataflowValue{get_only(op_pattern_1_added.outputs)}; + + NodeAddedResult op_pattern_2_added = + g.add_node(op_pattern_2, + {OpenDataflowValue{pt_x}, OpenDataflowValue{pt_c}}, + {pattern_tensor_y}); + PatternNode op_pattern_2_node = PatternNode{op_pattern_2_added.node}; + + PCGPattern pattern = PCGPattern{g}; + + std::unordered_set result = unordered_set_of( + find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); + + CHECK(result.size() == 3); + } } } diff --git a/lib/utils/include/utils/full_binary_tree/as_dot.h b/lib/utils/include/utils/full_binary_tree/as_dot.h new file mode 100644 index 0000000000..e104d05e06 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/as_dot.h @@ -0,0 +1,81 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_AS_DOT_H + +#include "utils/containers/get_only.h" +#include "utils/dot_file.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" +#include +#include +#include + +namespace FlexFlow { + +template +LabelledDataflowGraph as_labelled_dataflow_graph( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + std::function const &get_parent_label, + std::function const &get_leaf_label) { + auto g = LabelledDataflowGraph::template create< + UnorderedSetLabelledOpenDataflowGraph>(); + + FullBinaryTreeVisitor visitor = + FullBinaryTreeVisitor{ + [&](Parent const &parent) -> DataflowOutput { + DataflowOutput left_child_output = + visit(impl.get_left_child(parent), impl, visitor); + DataflowOutput right_child_output = + visit(impl.get_right_child(parent), impl, visitor); + NodeLabel parent_label = get_parent_label(parent); + NodeAddedResult parent_added = + g.add_node(parent_label, + {left_child_output, right_child_output}, + {std::monostate{}}); + return get_only(parent_added.outputs); + }, + [&](Leaf const &leaf) -> DataflowOutput { + NodeLabel leaf_label = get_leaf_label(leaf); + NodeAddedResult leaf_added = + g.add_node(leaf_label, {}, {std::monostate{}}); + return get_only(leaf_added.outputs); + }, + }; + + visit(tree, impl, visitor); + + return g; +} + +template +std::string + as_dot(Tree const &tree, + FullBinaryTreeImplementation const &impl, + std::function const &get_parent_label, + std::function const &get_leaf_label) { + + LabelledDataflowGraphView g = + as_labelled_dataflow_graph(tree, impl, get_parent_label, get_leaf_label); + + std::function get_node_label = + [](std::string const &s) { return s; }; + std::function get_input_label = + [](std::monostate const &) { return ""; }; + + return as_dot( + view_as_labelled_open_dataflow_graph(g), get_node_label, get_input_label); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h index 202058a3d1..f5bbbc228d 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_H +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" #include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" @@ -10,6 +11,17 @@ namespace FlexFlow { OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &, std::unordered_set const &); +bidict + get_full_graph_values_to_subgraph_inputs( + OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes); + +OpenDataflowGraphData + get_subgraph_data(OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + bidict const + &full_graph_values_to_subgraph_inputs); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index de48cd17e9..9b4ea6cd20 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -1,11 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include #include namespace FlexFlow { @@ -23,6 +25,10 @@ std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); +std::optional + binary_sp_decomposition_tree_get_subtree_at_path( + BinarySPDecompositionTree const &, BinaryTreePath const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h new file mode 100644 index 0000000000..9c999d8f6e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_AS_DOT_H + +#include "utils/full_binary_tree/as_dot.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +std::string as_dot( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + std::function const &get_series_label, + std::function const &get_parallel_label, + std::function const &get_leaf_label) { + FullBinaryTreeImplementation, Leaf> + full_binary_tree_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + std::function const &)> + get_parent_label = + [&](std::variant const &parent) -> std::string { + return std::visit(overload{ + [&](Series const &series) -> std::string { + return get_series_label(series); + }, + [&](Parallel const ¶llel) -> std::string { + return get_parallel_label(parallel); + }, + }, + parent); + }; + + return as_dot(tree, full_binary_tree_impl, get_parent_label, get_leaf_label); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 377561d70c..8673264d36 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -32,6 +32,11 @@ T const &assert_unwrap(std::optional const &o) { return o.value(); } +template +T expect(std::optional const &x, std::string const &err) { + return unwrap(x, [&]() { throw mk_runtime_error(err); }); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/full_binary_tree/as_dot.cc b/lib/utils/src/utils/full_binary_tree/as_dot.cc new file mode 100644 index 0000000000..12a1ab5533 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/as_dot.cc @@ -0,0 +1,16 @@ +#include "utils/full_binary_tree/as_dot.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::string + as_dot(Tree const &, + FullBinaryTreeImplementation const &, + std::function const &, + std::function const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc index ad3d4f26c0..36f027f792 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc @@ -4,7 +4,11 @@ #include "utils/containers/is_subseteq_of.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" @@ -13,100 +17,89 @@ namespace FlexFlow { -struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { - OpenDataflowSubgraph(OpenDataflowGraphView const &full_graph, - std::unordered_set const &subgraph_nodes, - bidict const - &full_graph_values_to_subgraph_inputs) - : full_graph(full_graph), subgraph_nodes(subgraph_nodes), - full_graph_values_to_subgraph_inputs( - full_graph_values_to_subgraph_inputs) { - assert(is_subseteq_of(this->subgraph_nodes, get_nodes(full_graph))); - } - - std::unordered_set query_nodes(NodeQuery const &q) const override { - return intersection(this->full_graph.query_nodes(q), this->subgraph_nodes); - } - - std::unordered_set - query_edges(OpenDataflowEdgeQuery const &q) const override { - std::unordered_set result; - for (OpenDataflowEdge const &open_e : this->full_graph.query_edges(q)) { - open_e.visit(overload{ - [&](DataflowEdge const &e) { - bool contains_src = contains(this->subgraph_nodes, e.src.node); - bool contains_dst = contains(this->subgraph_nodes, e.dst.node); - if (contains_src && contains_dst) { - result.insert(OpenDataflowEdge{e}); - } else if (contains_dst && !contains_src) { - result.insert(OpenDataflowEdge{DataflowInputEdge{ - this->full_graph_values_to_subgraph_inputs.at_l( - OpenDataflowValue{e.src}), - e.dst}}); - } - return std::nullopt; - }, - [&](DataflowInputEdge const &e) { - if (contains(this->subgraph_nodes, e.dst.node)) { - result.insert(OpenDataflowEdge{DataflowInputEdge{ - this->full_graph_values_to_subgraph_inputs.at_l( - OpenDataflowValue{e.src}), - e.dst}}); - } - return std::nullopt; - }}); - } - return result; - } - - std::unordered_set - query_outputs(DataflowOutputQuery const &q) const override { - return filter(this->full_graph.query_outputs(q), - [&](DataflowOutput const &o) { - return contains(this->subgraph_nodes, o.node); - }); - } - - std::unordered_set get_inputs() const override { - return unordered_set_of(values(this->full_graph_values_to_subgraph_inputs)); - }; - - OpenDataflowSubgraph *clone() const override { - return new OpenDataflowSubgraph{ - this->full_graph, - this->subgraph_nodes, - this->full_graph_values_to_subgraph_inputs, - }; - } - -private: - OpenDataflowGraphView full_graph; - std::unordered_set subgraph_nodes; - bidict - full_graph_values_to_subgraph_inputs; -}; - OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &g, std::unordered_set const &subgraph_nodes) { - DataflowGraphInputSource input_source; bidict - full_graph_values_to_subgraph_inputs = generate_bidict( - get_subgraph_inputs(g, subgraph_nodes), - [&](OpenDataflowValue const &v) -> DataflowGraphInput { - return v.visit(overload{ - [](DataflowGraphInput const &i) { return i; }, - [&](DataflowOutput const &) { - return input_source.new_dataflow_graph_input(); - }, - }); - }); + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(g, subgraph_nodes); return OpenDataflowSubgraphResult{ - OpenDataflowGraphView::create( - g, subgraph_nodes, full_graph_values_to_subgraph_inputs), + OpenDataflowGraphView::create( + get_subgraph_data( + g, subgraph_nodes, full_graph_values_to_subgraph_inputs)), full_graph_values_to_subgraph_inputs, }; } +bidict + get_full_graph_values_to_subgraph_inputs( + OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes) { + DataflowGraphInputSource input_source; + return generate_bidict(get_subgraph_inputs(g, subgraph_nodes), + [&](OpenDataflowValue const &v) -> DataflowGraphInput { + return v.visit(overload{ + [](DataflowGraphInput const &i) { return i; }, + [&](DataflowOutput const &) { + return input_source.new_dataflow_graph_input(); + }, + }); + }); +} + +OpenDataflowGraphData + get_subgraph_data(OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + bidict const + &full_graph_values_to_subgraph_inputs) { + std::unordered_set subgraph_input_edges = + transform(get_subgraph_incoming_edges(g, subgraph_nodes), + [&](OpenDataflowEdge const &edge) { + return edge.visit( + overload{[&](DataflowInputEdge const &e) { + return OpenDataflowEdge{DataflowInputEdge{ + full_graph_values_to_subgraph_inputs.at_l( + OpenDataflowValue{e.src}), + e.dst}}; + }, + [&](DataflowEdge const &e) { + return OpenDataflowEdge{DataflowInputEdge{ + full_graph_values_to_subgraph_inputs.at_l( + OpenDataflowValue{e.src}), + e.dst}}; + }}); + }); + + OpenDataflowEdgeQuery subgraph_interior_edges_query = OpenDataflowEdgeQuery{ + DataflowInputEdgeQuery{ + query_set::match_none(), + query_set::match_none(), + query_set::match_none(), + }, + DataflowEdgeQuery{ + query_set{subgraph_nodes}, + query_set::matchall(), + query_set{subgraph_nodes}, + query_set::matchall(), + }, + }; + std::unordered_set subgraph_interior_edges = + g.query_edges(subgraph_interior_edges_query); + + std::unordered_set subgraph_inputs = + unordered_set_of(values(full_graph_values_to_subgraph_inputs)); + std::unordered_set subgraph_outputs = + filter(g.query_outputs(dataflow_output_query_all()), + [&](DataflowOutput const &o) { + return contains(subgraph_nodes, o.node); + }); + return OpenDataflowGraphData{ + subgraph_nodes, + set_union(subgraph_input_edges, subgraph_interior_edges), + subgraph_inputs, + subgraph_outputs, + }; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 62489ff75f..3e4bc13289 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,5 +1,6 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" @@ -82,4 +83,10 @@ SPDecompositionTreeNodeType }); } +std::optional + binary_sp_decomposition_tree_get_subtree_at_path( + BinarySPDecompositionTree const &tree, BinaryTreePath const &path) { + return get_subtree_at_path(tree, generic_impl_for_binary_sp_tree(), path); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.cc new file mode 100644 index 0000000000..f557515c83 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.cc @@ -0,0 +1,21 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::string + as_dot(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + std::function const &, + std::function const &, + std::function const &); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/get_subgraph.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/get_subgraph.cc new file mode 100644 index 0000000000..c44e5f81b7 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/get_subgraph.cc @@ -0,0 +1,349 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/containers/contains.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_full_graph_values_to_subgraph_inputs(OpenDataflowGraphView, " + "std::unordered_set) ") { + OpenDataflowGraph graph = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = graph.add_input(); + DataflowGraphInput i1 = graph.add_input(); + DataflowGraphInput i2 = graph.add_input(); + + NodeAddedResult n0_added = graph.add_node({OpenDataflowValue{i0}}, 1_n); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + NodeAddedResult n1_added = graph.add_node({v0, OpenDataflowValue{i1}}, 1_n); + Node n1 = n1_added.node; + OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + + NodeAddedResult n2_added = graph.add_node({v0}, 1_n); + Node n2 = n2_added.node; + OpenDataflowValue v2 = OpenDataflowValue{get_only(n2_added.outputs)}; + + NodeAddedResult n3_added = + graph.add_node({OpenDataflowValue{i2}, v1, v2}, 1_n); + Node n3 = n3_added.node; + + std::unordered_set subgraph_nodes = {n1, n2, n3}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + SUBCASE("left entries are correct") { + std::unordered_set correct = { + v0, OpenDataflowValue{i1}, OpenDataflowValue{i2}}; + CHECK(left_entries(full_graph_values_to_subgraph_inputs) == correct); + } + + SUBCASE("mapping is correct") { + CHECK(full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{i1}) == + i1); + CHECK(full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{i2}) == + i2); + std::unordered_set inputs = {i1, i2}; + CHECK(!contains(inputs, full_graph_values_to_subgraph_inputs.at_l(v0))); + } + } + + TEST_CASE( + "get_subgraph_data(OpenDataflowGraphView, std::unordered_set, " + "bidict)") { + SUBCASE("2-node graph without inputs") { + OpenDataflowGraph graph = + OpenDataflowGraph::create(); + + NodeAddedResult n0_added = graph.add_node({}, 1_n); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + NodeAddedResult n1_added = graph.add_node({v0}, 1_n); + Node n1 = n1_added.node; + + SUBCASE("subgraph is full graph") { + std::unordered_set subgraph_nodes = {n0, n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + {OpenDataflowEdge{ + DataflowEdge{DataflowOutput{n0, 0_n}, DataflowInput{n1, 0_n}}}}, + {}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is n0") { + std::unordered_set subgraph_nodes = {n0}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + OpenDataflowGraphData correct = OpenDataflowGraphData{subgraph_nodes, + {}, + {}, + {DataflowOutput{ + n0, + 0_n, + }}}; + CHECK(result == correct); + } + + SUBCASE("subgraph is n1") { + std::unordered_set subgraph_nodes = {n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + DataflowGraphInput n0_as_subgraph_input = + full_graph_values_to_subgraph_inputs.at_l(v0); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + {OpenDataflowEdge{DataflowInputEdge{n0_as_subgraph_input, + DataflowInput{n1, 0_n}}}}, + {n0_as_subgraph_input}, + {DataflowOutput{ + n1, + 0_n, + }}}; + CHECK(result == correct); + } + + SUBCASE("subgraph is empty") { + std::unordered_set subgraph_nodes = {}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + OpenDataflowGraphData correct = + OpenDataflowGraphData{subgraph_nodes, {}, {}, {}}; + CHECK(result == correct); + } + } + + SUBCASE("3-node graph with inputs") { + OpenDataflowGraph graph = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = graph.add_input(); + DataflowGraphInput i1 = graph.add_input(); + + NodeAddedResult n0_added = graph.add_node({OpenDataflowValue{i0}}, 1_n); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + NodeAddedResult n1_added = + graph.add_node({v0, OpenDataflowValue{i1}}, 1_n); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = graph.add_node({v0}, 1_n); + Node n2 = n2_added.node; + + SUBCASE("subgraph is full graph") { + std::unordered_set subgraph_nodes = {n0, n1, n2}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n1, 0_n}}}, + OpenDataflowEdge{{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n2, 0_n}}}}, + }, + {i0, i1}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + DataflowOutput{ + n2, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n0, n1) split") { + std::unordered_set subgraph_nodes = {n0, n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n1, 0_n}}}, + }, + {i0, i1}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n0, n1) split") { + std::unordered_set subgraph_nodes = {n0, n1}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n1, 0_n}}}, + }, + {i0, i1}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n1, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n0, n2) split") { + std::unordered_set subgraph_nodes = {n0, n2}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i0, DataflowInput{n0, 0_n}}}, + OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0_n}, + DataflowInput{n2, 0_n}}}, + }, + {i0}, + { + DataflowOutput{ + n0, + 0_n, + }, + DataflowOutput{ + n2, + 0_n, + }, + }}; + CHECK(result == correct); + } + + SUBCASE("subgraph is (n1, n2) split") { + std::unordered_set subgraph_nodes = {n1, n2}; + + bidict + full_graph_values_to_subgraph_inputs = + get_full_graph_values_to_subgraph_inputs(graph, subgraph_nodes); + + OpenDataflowGraphData result = get_subgraph_data( + graph, subgraph_nodes, full_graph_values_to_subgraph_inputs); + + DataflowGraphInput n0_as_subgraph_input = + full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{v0}); + + OpenDataflowGraphData correct = OpenDataflowGraphData{ + subgraph_nodes, + { + OpenDataflowEdge{DataflowInputEdge{i1, DataflowInput{n1, 1_n}}}, + OpenDataflowEdge{DataflowInputEdge{n0_as_subgraph_input, + DataflowInput{n1, 0_n}}}, + OpenDataflowEdge{DataflowInputEdge{n0_as_subgraph_input, + DataflowInput{n2, 0_n}}}, + }, + {i1, n0_as_subgraph_input}, + { + DataflowOutput{ + n1, + 0_n, + }, + DataflowOutput{ + n2, + 0_n, + }, + }}; + CHECK(result == correct); + } + } + } +}