From f5f32ee2350c16954d5e55301a3f1574a67f535f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 5 Dec 2025 06:02:36 -0300 Subject: [PATCH 1/7] wip: sharding poc --- exla/c_src/exla/exla.cc | 86 ++++++++++++++++++++++++++++++++++ exla/c_src/exla/exla_mlir.h | 2 + exla/lib/exla/defn.ex | 39 +++++++++++++++ exla/lib/exla/mlir/function.ex | 20 ++++++++ exla/lib/exla/mlir/module.ex | 8 ++++ exla/lib/exla/nif.ex | 5 ++ exla/lib/exla/sharding.ex | 75 +++++++++++++++++++++++++++++ 7 files changed, 235 insertions(+) create mode 100644 exla/lib/exla/sharding.ex diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 9019c8ee85..27610fccb8 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -14,6 +14,8 @@ #include "exla_nif_util.h" #include "ipc.h" #include "mlir/IR/MLIRContext.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/register.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/pjrt/pjrt_api.h" @@ -67,6 +69,9 @@ mlir_new_context(ErlNifEnv *env, context->getOrLoadDialect(); context->getOrLoadDialect(); context->getOrLoadDialect(); + context->getOrLoadDialect(); + mlir::sdy::registerAllDialects( + const_cast(context->getDialectRegistry())); return context; } @@ -171,6 +176,87 @@ fine::Ok<> mlir_pop_region(ErlNifEnv *env, FINE_NIF(mlir_pop_region, 0); +fine::Ok<> mlir_add_mesh(ErlNifEnv *env, fine::ResourcePtr module, + std::string mesh_name, + std::vector> axes) { + auto builder = module->builder(); + auto context = module->module()->getContext(); + + llvm::SmallVector axis_attrs; + for (auto [name, size] : axes) { + axis_attrs.push_back(mlir::sdy::MeshAxisAttr::get(context, name, size)); + } + + auto mesh_attr = mlir::sdy::MeshAttr::get(context, axis_attrs); + + // Create the mesh op at the beginning of the module + auto module_op = module->module(); + auto &body_region = module_op.getBodyRegion(); + mlir::OpBuilder::InsertionGuard guard(*builder); + builder->setInsertionPointToStart(&body_region.front()); + + mlir::OperationState state(builder->getUnknownLoc(), "sdy.mesh"); + mlir::sdy::MeshOp::build(*builder, state, mesh_name, mesh_attr); + builder->create(state); + + return fine::Ok(); +} + +FINE_NIF(mlir_add_mesh, 0); + +mlir::sdy::TensorShardingAttr mlir_create_tensor_sharding_attr( + mlir::MLIRContext *context, std::string mesh_name, + std::vector> dim_shardings) { + llvm::SmallVector dim_sharding_attrs; + for (const auto &dim : dim_shardings) { + llvm::SmallVector axis_refs; + for (const auto &axis : dim) { + axis_refs.push_back(mlir::sdy::AxisRefAttr::get(context, axis)); + } + dim_sharding_attrs.push_back(mlir::sdy::DimensionShardingAttr::get( + context, axis_refs, /*is_closed=*/false, /*priority=*/0)); + } + + return mlir::sdy::TensorShardingAttr::get( + context, mesh_name, dim_sharding_attrs, + /*replicated_axes=*/llvm::ArrayRef(), + /*unreduced_axes=*/llvm::ArrayRef()); +} + +fine::Ok<> +mlir_set_arg_sharding(ErlNifEnv *env, fine::ResourcePtr function, + int64_t arg_index, std::string mesh_name, + std::vector> dim_shardings) { + + auto context = function->module()->module()->getContext(); + auto sharding_attr = + mlir_create_tensor_sharding_attr(context, mesh_name, dim_shardings); + + function->function().setArgAttr(arg_index, "sdy.sharding", sharding_attr); + + return fine::Ok(); +} + +FINE_NIF(mlir_set_arg_sharding, 0); + +fine::Ok<> +mlir_set_result_sharding(ErlNifEnv *env, + fine::ResourcePtr function, + int64_t 
result_index, std::string mesh_name, + std::vector> dim_shardings) { + + auto context = function->module()->module()->getContext(); + auto sharding_attr = + mlir_create_tensor_sharding_attr(context, mesh_name, dim_shardings); + + function->function().setResultAttr(result_index, "sdy.sharding", + sharding_attr); + + return fine::Ok(); +} + +FINE_NIF(mlir_set_result_sharding, 0); + mlir::Type mlir_get_typespec(ErlNifEnv *env, fine::ResourcePtr value) { return value->getType(); diff --git a/exla/c_src/exla/exla_mlir.h b/exla/c_src/exla/exla_mlir.h index 095ad4c1a7..5d5ef060ee 100644 --- a/exla/c_src/exla/exla_mlir.h +++ b/exla/c_src/exla/exla_mlir.h @@ -29,6 +29,8 @@ class MLIRFunction { llvm::MutableArrayRef GetArguments() { return func_->getBody().front().getArguments(); } + mlir::func::FuncOp function() { return *func_; } + fine::ResourcePtr module() { return module_; } private: diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 275528fa28..5cb03a037f 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -228,6 +228,9 @@ defmodule EXLA.Defn do outfeed = Outfeed.new(hooks, defined_hooks) comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options} + mesh = Keyword.get(options, :mesh) + input_shardings = Keyword.get(options, :input_shardings, []) + {comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} = :timer.tc(fn -> comp_cache_fun.(comp_key, fn -> @@ -254,6 +257,29 @@ defmodule EXLA.Defn do end) EXLA.MLIR.Module.new(comp_typespecs, out_typespecs, fn builder -> + # Add device mesh to module if provided + if mesh do + EXLA.MLIR.Module.add_mesh(builder.module, mesh) + end + + if !mesh and input_shardings != [] do + raise ArgumentError, "input sharding configs provided but no device mesh was provided" + end + + # Apply sharding annotations to function arguments if provided + if input_shardings != [] do + num_comp_args = length(comp_typespecs) + + if length(input_shardings) != num_comp_args do + raise ArgumentError, + "expected #{num_comp_args} input sharding configs (one per argument), got #{length(input_shardings)}" + end + + Enum.with_index(input_shardings, fn sharding, arg_index -> + Function.set_arg_sharding(builder, arg_index, sharding) + end) + end + # Only create the token when we know it will actually be # used, that is: streaming, lazy transfers or hooks outfeed = @@ -270,6 +296,19 @@ defmodule EXLA.Defn do options = Keyword.put(options, :callback_server_pid, callback_server_pid) + # Compute num_partitions from mesh and enable SPMD if mesh is provided + options = + if mesh do + num_partitions = + Enum.reduce(mesh.axes, 1, fn {_name, size}, acc -> acc * size end) + + options + |> Keyword.put(:num_partitions, num_partitions) + |> Keyword.put(:use_spmd, true) + else + options + end + {xla_time, executable} = :timer.tc(fn -> EXLA.MLIR.Module.compile( diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex index 7b1157955a..9b1e120b48 100644 --- a/exla/lib/exla/mlir/function.ex +++ b/exla/lib/exla/mlir/function.ex @@ -36,4 +36,24 @@ defmodule EXLA.MLIR.Function do def pop_region(%Function{ref: ref}) do EXLA.NIF.mlir_pop_region(ref) end + + @doc """ + Sets sharding annotation for a function argument. + """ + def set_arg_sharding(%Function{ref: ref}, arg_index, %EXLA.Sharding.TensorSharding{ + mesh_name: mesh, + axes: dims + }) do + EXLA.NIF.mlir_set_arg_sharding(ref, arg_index, mesh, dims) + end + + @doc """ + Sets sharding annotation for a function result. 
+ """ + def set_result_sharding(%Function{ref: ref}, result_index, %EXLA.Sharding.TensorSharding{ + mesh_name: mesh, + axes: dims + }) do + EXLA.NIF.mlir_set_result_sharding(ref, result_index, mesh, dims) + end end diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index d1ba3d0b0b..f8be0f771c 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -134,6 +134,14 @@ defmodule EXLA.MLIR.Module do } end + @doc """ + Adds a device mesh definition to the module. + """ + def add_mesh(%__MODULE__{ref: module_ref}, %EXLA.Sharding.DeviceMesh{name: name, axes: axes}) do + EXLA.NIF.mlir_add_mesh(module_ref, name, axes) + :ok + end + @doc """ Returns a human-readable representation of the module using MLIR syntax. diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 6e70d07d57..281d5c3896 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -28,6 +28,9 @@ defmodule EXLA.NIF do def mlir_op(_function, _op_name, _operands, _result_type, _attributes, _blocks), do: err!() def mlir_push_region(_function, _arg_types), do: err!() def mlir_pop_region(_function), do: err!() + def mlir_add_mesh(_module, _mesh_name, _axes), do: err!() + def mlir_set_arg_sharding(_function, _arg_index, _mesh_name, _dim_shardings), do: err!() + def mlir_set_result_sharding(_function, _result_index, _mesh_name, _dim_shardings), do: err!() def mlir_build(_function, _root), do: err!() def mlir_compile( @@ -79,6 +82,8 @@ defmodule EXLA.NIF do def reset_peak_memory(_client), do: err!() def get_per_device_memory(_client), do: err!() + def ensure_shardy_included(), do: err!() + # Elixir callback bridge def start_runtime_callback_bridge(_dispatcher_pid), do: err!() def clear_runtime_callback_bridge(_dispatcher_pid), do: err!() diff --git a/exla/lib/exla/sharding.ex b/exla/lib/exla/sharding.ex new file mode 100644 index 0000000000..4293a580e7 --- /dev/null +++ b/exla/lib/exla/sharding.ex @@ -0,0 +1,75 @@ +defmodule EXLA.Sharding do + @moduledoc """ + Helper module for defining Shardy device meshes and tensor sharding specifications. + """ + + defmodule DeviceMesh do + @moduledoc """ + Represents a device mesh configuration. + """ + @enforce_keys [:name, :axes] + defstruct [:name, :axes] + + @type axis :: {name :: String.t(), size :: pos_integer()} + @type t :: %__MODULE__{ + name: String.t(), + axes: [axis()] + } + end + + defmodule TensorSharding do + @moduledoc """ + Represents a sharding specification for a tensor. + """ + @enforce_keys [:mesh_name, :axes] + defstruct [:mesh_name, :axes] + + @type dim_sharding :: [String.t()] + @type t :: %__MODULE__{ + mesh_name: String.t(), + axes: [dim_sharding()] + } + end + + @doc """ + Creates a device mesh definition. + + ## Examples + + iex> EXLA.Sharding.mesh(:my_mesh, x: 2, y: 4) + %EXLA.Sharding.DeviceMesh{name: "my_mesh", axes: [{"x", 2}, {"y", 4}]} + """ + def mesh(name, axes) when (is_atom(name) or is_binary(name)) and is_list(axes) do + normalized_axes = + Enum.map(axes, fn {k, v} -> {to_string(k), v} end) + + %DeviceMesh{name: to_string(name), axes: normalized_axes} + end + + @doc """ + Creates a sharding specification for a tensor. + + The `dim_shardings` list must match the rank of the tensor. + Each element is a list of axis names that the corresponding dimension is sharded on. 
+ + ## Examples + + # Rank 2 tensor, dim 0 sharded on "x", dim 1 sharded on "y" + iex> EXLA.Sharding.sharding(:my_mesh, [["x"], ["y"]]) + %EXLA.Sharding.TensorSharding{mesh_name: "my_mesh", axes: [["x"], ["y"]]} + + # Rank 2 tensor, dim 0 sharded on "x", dim 1 replicated + iex> EXLA.Sharding.sharding(:my_mesh, [["x"], []]) + %EXLA.Sharding.TensorSharding{mesh_name: "my_mesh", axes: [["x"], []]} + """ + def sharding(mesh_name, dim_shardings) do + %TensorSharding{mesh_name: to_string(mesh_name), axes: dim_shardings} + end + + @doc """ + Creates a fully replicated sharding specification (empty list for all dims). + """ + def replicated(mesh_name, rank) do + %TensorSharding{mesh_name: to_string(mesh_name), axes: List.duplicate([], rank)} + end +end From 9d12202a2cc9778f0a58b5fe063176b7c7bb6103 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 5 Dec 2025 06:51:15 -0300 Subject: [PATCH 2/7] minimal working sharding example --- exla/c_src/exla/exla_client.cc | 80 ++++++++++++++++++++++++---------- exla/lib/exla/defn.ex | 55 ++++++++++++++++++++++- exla/lib/exla/executable.ex | 21 +++++++-- exla/lib/exla/mlir/module.ex | 4 +- exla/sharding.exs | 10 +++++ 5 files changed, 141 insertions(+), 29 deletions(-) create mode 100644 exla/sharding.exs diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc index c2f5d0fd8d..43cb75650c 100644 --- a/exla/c_src/exla/exla_client.cc +++ b/exla/c_src/exla/exla_client.cc @@ -152,15 +152,24 @@ PjRtBufferFromBinary(xla::PjRtClient *client, ERL_NIF_TERM source_term, tsl::StatusOr>> UnpackRunArguments( ErlNifEnv *env, ExlaExecutable::RunArguments arguments, std::vector> &transient_buffers, - ExlaClient *client, xla::DeviceAssignment device_assignment, - int device_id) { + ExlaClient *client, xla::DeviceAssignment device_assignment, int device_id, + int num_partitions) { std::vector> arg_buffers; arg_buffers.reserve(arguments.size()); - int replica = 0; + int index = 0; for (const auto &replica_arguments : arguments) { - auto device = device_id >= 0 ? device_id : device_assignment(replica, 0); + // For automatic SPMD: each input list goes to a different partition device + // device_assignment is (replica, partition) -> device + // With num_partitions > 1, we iterate through partitions (replica=0, + // partition=0..N-1) For replication, we iterate through replicas + // (replica=0..N-1, partition=0) + int replica = (num_partitions > 1) ? 0 : index; + int partition = (num_partitions > 1) ? index : 0; + + auto device = + device_id >= 0 ? device_id : device_assignment(replica, partition); auto replica_buffers = std::vector(); replica_buffers.reserve(replica_arguments.size()); @@ -200,7 +209,7 @@ tsl::StatusOr>> UnpackRunArguments( arg_buffers.push_back(std::move(replica_buffers)); - replica++; + index++; } return arg_buffers; @@ -216,7 +225,17 @@ UnpackResult(ErlNifEnv *env, for (int i = 0; i < result.size(); i++) { auto replica_results = std::vector>(); - int64_t device = device_id >= 0 ? 
device_id : device_assignment(i, 0); + + int64_t device; + if (device_id >= 0) { + device = device_id; + } else if (device_assignment.computation_count() > 1) { + // SPMD: results correspond to partitions (replica 0, partition i) + device = device_assignment(0, i); + } else { + // Replication: results correspond to replicas (replica i, partition 0) + device = device_assignment(i, 0); + } for (auto &pjrt_buf : result.at(i)) { pjrt_buf->GetReadyFuture().Await(); @@ -266,20 +285,23 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments, // a pmap, but in all other cases it will be equal to 1 int num_replicas = executable_->num_replicas(); + // the number of partitions is used for SPMD partitioning + int num_partitions = executable_->num_partitions(); + // input buffers are a list of lists, where each list maps to the args // to pass to one of the replicas in a computation, e.g. [replica_args1, // replica_args2, ...] std::vector> input_buffers; // the device assignment is a 2d array which maps coordinates (replica, - // partition) to a device; or in this case just maps a replica to a device + // partition) to a device xla::DeviceAssignment device_assignment; if (client_->client()->platform_name() == "METAL") { device_assignment = xla::DeviceAssignment(1, 1); } else { - EXLA_ASSIGN_OR_RETURN( - device_assignment, - client_->client()->GetDefaultDeviceAssignment(num_replicas, 1)); + EXLA_ASSIGN_OR_RETURN(device_assignment, + client_->client()->GetDefaultDeviceAssignment( + num_replicas, num_partitions)); } // Buffers allocated from binaries for this specific run need to be @@ -300,15 +322,20 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments, EXLA_ASSIGN_OR_RETURN(input_buffers, UnpackRunArguments(env, arguments, transient_buffers, client_, device_assignment, - device_id)); + device_id, num_partitions)); } - // at this point input buffers is a vector of arguments per replica - // and the size of that vector should equal the number of replicas in the - // executable, otherwise it is invalid - if (num_replicas != input_buffers.size()) { - return xla::InvalidArgument("Got %d replica arguments for %d replicas", - input_buffers.size(), num_replicas); + // at this point input buffers is a vector of arguments per device + // For automatic SPMD: one input list per partition (num_partitions lists) + // For standard replication: one input list per replica (num_replicas lists) + // Each input list contains full unreplicated tensors; XLA slices based on + // sharding + int expected_lists = num_partitions > 1 ? num_partitions : num_replicas; + if (input_buffers.size() != expected_lists) { + return xla::InvalidArgument("Got %d argument lists, expected %d " + "(num_replicas=%d, num_partitions=%d)", + input_buffers.size(), expected_lists, + num_replicas, num_partitions); } std::vector>> @@ -333,10 +360,9 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments, // result buffers to unpack per_replica_results.push_back(std::move(portable_result)); } else { - // no device ID is present, so it may be a replicated executable which means - // we need to use the replica execution path - // TODO: This now exposes a `returned_futures` API, does this make sense for - // us? 
+ // no device ID is present, so it may be a replicated or SPMD executable + // For SPMD with num_partitions > 1, Execute handles partitioned execution + // using sharding annotations EXLA_ASSIGN_OR_RETURN(per_replica_results, executable_->Execute(input_buffers, options)); } @@ -344,9 +370,15 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments, // EXLA_ASSIGN_OR_RETURN(per_replica_results, // executable_->Execute(input_buffers, options)); - // sanity check - if (per_replica_results.size() != num_replicas) { - return xla::FailedPrecondition("Invalid execution."); + // sanity check - for SPMD we get results per partition, for replication per + // replica + int expected_results = num_partitions > 1 ? num_partitions : num_replicas; + if (per_replica_results.size() != expected_results) { + return xla::FailedPrecondition( + "Invalid execution: got %d results, expected %d (num_replicas=%d, " + "num_partitions=%d)", + per_replica_results.size(), expected_results, num_replicas, + num_partitions); } // we need to unpack the results into Erlang terms, the result is a vector diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 5cb03a037f..339df354ca 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -149,10 +149,19 @@ defmodule EXLA.Defn do EXLA.Defn.Buffers.from_nx!(arg, executable) end) - EXLA.Executable.run(executable, [buffers], run_options) + input_lists = slice_inputs(buffers, executable) + + EXLA.Executable.run(executable, input_lists, run_options) else [result] -> [EXLA.Defn.Buffers.to_nx!(result, outputs)] + + results when is_list(results) -> + # For SPMD, we get multiple results (one per partition). + # For now, we just take the first one to verify execution. + # TODO: Implement re-assembly of sharded outputs + [first | _] = results + [EXLA.Defn.Buffers.to_nx!(first, outputs)] after EXLA.Defn.Lock.unlock(lock) end @@ -160,6 +169,50 @@ defmodule EXLA.Defn do defp run_key(%{client: %{ref: ref}, device_id: device_id}), do: [ref | device_id] + defp slice_inputs(buffers, %EXLA.Executable{num_partitions: 1}), do: [buffers] + + defp slice_inputs( + buffers, + %EXLA.Executable{mesh: _mesh, input_shardings: _shardings, num_partitions: np} + ) + when np > 1 do + # TODO: Implement generic slicing based on mesh and input_shardings. + # Currently hardcoded for 2x2 mesh testing. 
+ if np == 4 and length(buffers) == 2 do + [%{data: data0, typespec: type0}, %{data: data1, typespec: type1}] = buffers + + s0_0 = binary_part(data0, 0, 4) + s0_1 = binary_part(data0, 4, 4) + s0_2 = binary_part(data0, 8, 4) + s0_3 = binary_part(data0, 12, 4) + + s1_0 = binary_part(data1, 0, 4) + s1_1 = binary_part(data1, 0, 4) + s1_2 = binary_part(data1, 4, 4) + s1_3 = binary_part(data1, 4, 4) + + t0 = %{type0 | shape: {1, 1}} + t1 = %{type1 | shape: {1, 1}} + + wrap = fn data, type -> + %EXLA.BinaryBuffer{data: data, typespec: type} + end + + [ + [wrap.(s0_0, t0), wrap.(s1_0, t1)], + [wrap.(s0_1, t0), wrap.(s1_1, t1)], + [wrap.(s0_2, t0), wrap.(s1_2, t1)], + [wrap.(s0_3, t0), wrap.(s1_3, t1)] + ] + else + # Fallback for unsupported cases + List.duplicate(buffers, np) + end + end + + defp slice_inputs(buffers, %EXLA.Executable{num_partitions: np}), + do: List.duplicate(buffers, np) + ## Compile defp compile( diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 15ffbbdfe0..c49b74ed13 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -7,7 +7,16 @@ defmodule EXLA.Executable do alias EXLA.{BinaryBuffer, DeviceBuffer} @enforce_keys [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] - defstruct [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] + defstruct [ + :client, + :ref, + :output_typespecs, + :num_replicas, + :num_partitions, + :device_id, + :mesh, + :input_shardings + ] @doc """ Runs the given executable with a list of lists as inputs and the given options. @@ -45,7 +54,9 @@ defmodule EXLA.Executable do output_typespecs: output_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, - device_id: device_id + device_id: device_id, + mesh: mesh, + input_shardings: input_shardings }) when node(ref) == node() do serialized_exec = @@ -58,7 +69,9 @@ defmodule EXLA.Executable do output_typespecs: output_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, - device_id: device_id + device_id: device_id, + mesh: mesh, + input_shardings: input_shardings } end @@ -85,6 +98,8 @@ defmodule EXLA.Executable do num_replicas: num_replicas, num_partitions: num_partitions, device_id: device_id, + mesh: Map.get(data, :mesh), + input_shardings: Map.get(data, :input_shardings), ref: ref, client: client } diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index f8be0f771c..fceb607e7e 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -130,7 +130,9 @@ defmodule EXLA.MLIR.Module do output_typespecs: return_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, - device_id: device_id + device_id: device_id, + mesh: Keyword.get(options, :mesh), + input_shardings: Keyword.get(options, :input_shardings) } end diff --git a/exla/sharding.exs b/exla/sharding.exs new file mode 100644 index 0000000000..df60fa597e --- /dev/null +++ b/exla/sharding.exs @@ -0,0 +1,10 @@ +fun = fn x, y -> {Nx.add(x, y), Nx.multiply(x, y)} end +args = [Nx.iota({2, 2}), Nx.iota({2, 1})] + +mesh = EXLA.Sharding.mesh("mesh", x: 2, y: 2) + +input_shardings = [EXLA.Sharding.sharding("mesh", [["x"], ["y"]]), EXLA.Sharding.sharding("mesh", [["x"], ["y"]])] + +result = EXLA.jit_apply(fun, args, mesh: mesh, input_shardings: input_shardings) + +dbg(result) From 2322dd673b52e5a233830cdb036b799e787605f0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 5 Dec 2025 07:14:18 -0300 Subject: [PATCH 3/7] 
feat: mesh-based slicing --- exla/lib/exla/defn.ex | 149 ++++++++++++++++++++++++++++++++---------- exla/sharding.exs | 11 ++-- 2 files changed, 123 insertions(+), 37 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 339df354ca..6ca4dfb43c 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -173,46 +173,129 @@ defmodule EXLA.Defn do defp slice_inputs( buffers, - %EXLA.Executable{mesh: _mesh, input_shardings: _shardings, num_partitions: np} + %EXLA.Executable{ + mesh: mesh, + input_shardings: shardings, + num_partitions: np + } ) - when np > 1 do - # TODO: Implement generic slicing based on mesh and input_shardings. - # Currently hardcoded for 2x2 mesh testing. - if np == 4 and length(buffers) == 2 do - [%{data: data0, typespec: type0}, %{data: data1, typespec: type1}] = buffers - - s0_0 = binary_part(data0, 0, 4) - s0_1 = binary_part(data0, 4, 4) - s0_2 = binary_part(data0, 8, 4) - s0_3 = binary_part(data0, 12, 4) - - s1_0 = binary_part(data1, 0, 4) - s1_1 = binary_part(data1, 0, 4) - s1_2 = binary_part(data1, 4, 4) - s1_3 = binary_part(data1, 4, 4) - - t0 = %{type0 | shape: {1, 1}} - t1 = %{type1 | shape: {1, 1}} - - wrap = fn data, type -> - %EXLA.BinaryBuffer{data: data, typespec: type} - end - - [ - [wrap.(s0_0, t0), wrap.(s1_0, t1)], - [wrap.(s0_1, t0), wrap.(s1_1, t1)], - [wrap.(s0_2, t0), wrap.(s1_2, t1)], - [wrap.(s0_3, t0), wrap.(s1_3, t1)] - ] - else - # Fallback for unsupported cases - List.duplicate(buffers, np) + when np > 1 and not is_nil(mesh) and not is_nil(shardings) do + # Build mesh axis map for quick lookup + mesh_axes = Map.new(mesh.axes) + + # Generate shards for each partition + for partition_idx <- 0..(np - 1) do + # Convert linear partition index to mesh coordinates + coords = unravel_index(partition_idx, mesh.axes) + + # Slice each buffer according to its sharding spec + Enum.zip(buffers, shardings) + |> Enum.map(fn {buffer, sharding} -> + slice_buffer_for_partition(buffer, sharding, coords, mesh_axes) + end) end end defp slice_inputs(buffers, %EXLA.Executable{num_partitions: np}), do: List.duplicate(buffers, np) + # Converts linear partition index to mesh coordinates + # Example: index 3 in [x: 2, y: 2] -> %{x: 1, y: 1} + defp unravel_index(index, axes) do + {coords, _} = + Enum.reduce(Enum.reverse(axes), {%{}, index}, fn {name, size}, {acc, current_idx} -> + coord = rem(current_idx, size) + remaining = div(current_idx, size) + {Map.put(acc, name, coord), remaining} + end) + + coords + end + + # Slices a single buffer for a specific partition based on sharding spec + defp slice_buffer_for_partition( + %EXLA.BinaryBuffer{data: data, typespec: typespec}, + sharding, + coords, + mesh_axes + ) do + # Convert binary buffer to Nx tensor + tensor = binary_buffer_to_nx(data, typespec) + + # Slice along each dimension according to sharding spec + sharded_tensor = + tensor.shape + |> Tuple.to_list() + |> Enum.with_index() + |> Enum.reduce(tensor, fn {dim_size, dim_idx}, acc -> + axis_names = Enum.at(sharding.axes, dim_idx, []) + + if axis_names == [] do + # Dimension is replicated, keep full dimension + acc + else + # Special case: size 1 dimensions cannot be sharded + # Treat them as replicated (effectively remove sharding) + if dim_size == 1 do + acc + else + # Calculate total number of shards for this dimension + # (product of all mesh axes this dimension is sharded on) + shards_count = + Enum.reduce(axis_names, 1, fn name, acc -> + acc * Map.fetch!(mesh_axes, name) + end) + + # Error if dimension size is less than shards_count 
(and not size 1) + if dim_size < shards_count do + raise ArgumentError, + "Cannot shard dimension #{dim_idx} of size #{dim_size} across #{shards_count} shards. " <> + "Dimension size must be >= shards_count (or size 1 for implicit replication)" + end + + # Calculate chunk size (assuming even division) + chunk_size = div(dim_size, shards_count) + + # Calculate slice index for this partition + slice_idx = + case axis_names do + [name] -> + Map.fetch!(coords, name) + + _ -> + # Multi-axis sharding: calculate linear index from coordinates + # This handles the cartesian product of mesh axes + Enum.reduce(axis_names, 0, fn name, acc -> + coord = Map.fetch!(coords, name) + axis_size = Map.fetch!(mesh_axes, name) + acc * axis_size + coord + end) + end + + # Normal case: evenly divisible + start = slice_idx * chunk_size + Nx.slice_along_axis(acc, start, chunk_size, axis: dim_idx) + end + end + end) + + # Convert back to BinaryBuffer + nx_to_binary_buffer(sharded_tensor) + end + + # Converts BinaryBuffer to Nx tensor + defp binary_buffer_to_nx(data, %EXLA.Typespec{type: type, shape: shape}) do + Nx.from_binary(data, type) |> Nx.reshape(shape) + end + + # Converts Nx tensor to BinaryBuffer + defp nx_to_binary_buffer(tensor) do + %EXLA.BinaryBuffer{ + data: Nx.to_binary(tensor), + typespec: %EXLA.Typespec{type: tensor.type, shape: tensor.shape} + } + end + ## Compile defp compile( diff --git a/exla/sharding.exs b/exla/sharding.exs index df60fa597e..7ae64d2b54 100644 --- a/exla/sharding.exs +++ b/exla/sharding.exs @@ -1,10 +1,13 @@ fun = fn x, y -> {Nx.add(x, y), Nx.multiply(x, y)} end -args = [Nx.iota({2, 2}), Nx.iota({2, 1})] +args = [Nx.iota({8, 2}), Nx.iota({8, 1})] -mesh = EXLA.Sharding.mesh("mesh", x: 2, y: 2) +mesh = EXLA.Sharding.mesh("mesh", x: 4, y: 2) -input_shardings = [EXLA.Sharding.sharding("mesh", [["x"], ["y"]]), EXLA.Sharding.sharding("mesh", [["x"], ["y"]])] +input_shardings = [EXLA.Sharding.sharding("mesh", [["x"], ["y"]]), EXLA.Sharding.sharding("mesh", [["x"], []])] -result = EXLA.jit_apply(fun, args, mesh: mesh, input_shardings: input_shardings) +result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings) + +IO.puts(result.mlir_module) +result = EXLA.jit_apply(fun, args, mesh: mesh, input_shardings: input_shardings) dbg(result) From abeabbda6e07a7a7ff4913a9d2674a92008c7d0f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 5 Dec 2025 07:18:59 -0300 Subject: [PATCH 4/7] wip --- exla/sharding.exs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exla/sharding.exs b/exla/sharding.exs index 7ae64d2b54..a43aeecda4 100644 --- a/exla/sharding.exs +++ b/exla/sharding.exs @@ -1,9 +1,9 @@ fun = fn x, y -> {Nx.add(x, y), Nx.multiply(x, y)} end args = [Nx.iota({8, 2}), Nx.iota({8, 1})] -mesh = EXLA.Sharding.mesh("mesh", x: 4, y: 2) +mesh = EXLA.Sharding.mesh("mesh", x: 2, y: 2, z: 2) -input_shardings = [EXLA.Sharding.sharding("mesh", [["x"], ["y"]]), EXLA.Sharding.sharding("mesh", [["x"], []])] +input_shardings = [EXLA.Sharding.sharding("mesh", [["x", "z"], ["y"]]), EXLA.Sharding.sharding("mesh", [["x", "z"], []])] result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings) @@ -11,3 +11,5 @@ IO.puts(result.mlir_module) result = EXLA.jit_apply(fun, args, mesh: mesh, input_shardings: input_shardings) dbg(result) + +# run with: XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text --xla_force_host_platform_device_count=10" mix run sharding.exs From 
6a35121191480903a96a6d4242ae1acb865c841f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 5 Dec 2025 07:29:41 -0300 Subject: [PATCH 5/7] chore: remove set_result_sharding --- exla/c_src/exla/exla.cc | 18 ------------------ exla/lib/exla/mlir/function.ex | 10 ---------- exla/lib/exla/nif.ex | 1 - 3 files changed, 29 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 27610fccb8..f6efca3b04 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -239,24 +239,6 @@ mlir_set_arg_sharding(ErlNifEnv *env, fine::ResourcePtr function, FINE_NIF(mlir_set_arg_sharding, 0); -fine::Ok<> -mlir_set_result_sharding(ErlNifEnv *env, - fine::ResourcePtr function, - int64_t result_index, std::string mesh_name, - std::vector> dim_shardings) { - - auto context = function->module()->module()->getContext(); - auto sharding_attr = - mlir_create_tensor_sharding_attr(context, mesh_name, dim_shardings); - - function->function().setResultAttr(result_index, "sdy.sharding", - sharding_attr); - - return fine::Ok(); -} - -FINE_NIF(mlir_set_result_sharding, 0); - mlir::Type mlir_get_typespec(ErlNifEnv *env, fine::ResourcePtr value) { return value->getType(); diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex index 9b1e120b48..80d59decc9 100644 --- a/exla/lib/exla/mlir/function.ex +++ b/exla/lib/exla/mlir/function.ex @@ -46,14 +46,4 @@ defmodule EXLA.MLIR.Function do }) do EXLA.NIF.mlir_set_arg_sharding(ref, arg_index, mesh, dims) end - - @doc """ - Sets sharding annotation for a function result. - """ - def set_result_sharding(%Function{ref: ref}, result_index, %EXLA.Sharding.TensorSharding{ - mesh_name: mesh, - axes: dims - }) do - EXLA.NIF.mlir_set_result_sharding(ref, result_index, mesh, dims) - end end diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 281d5c3896..28b93bca19 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -30,7 +30,6 @@ defmodule EXLA.NIF do def mlir_pop_region(_function), do: err!() def mlir_add_mesh(_module, _mesh_name, _axes), do: err!() def mlir_set_arg_sharding(_function, _arg_index, _mesh_name, _dim_shardings), do: err!() - def mlir_set_result_sharding(_function, _result_index, _mesh_name, _dim_shardings), do: err!() def mlir_build(_function, _root), do: err!() def mlir_compile( From 475120131f647c744b7ab11c9714a1e69d001167 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 5 Dec 2025 07:30:09 -0300 Subject: [PATCH 6/7] chore: remove ensure_shardy_included --- exla/lib/exla/nif.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 28b93bca19..85d051f835 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -81,8 +81,6 @@ defmodule EXLA.NIF do def reset_peak_memory(_client), do: err!() def get_per_device_memory(_client), do: err!() - def ensure_shardy_included(), do: err!() - # Elixir callback bridge def start_runtime_callback_bridge(_dispatcher_pid), do: err!() def clear_runtime_callback_bridge(_dispatcher_pid), do: err!() From e0c9cd57ecc772551e6dec386b5ca4a9fe4395c5 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 5 Dec 2025 07:30:44 -0300 Subject: [PATCH 7/7] chore: remove redundant func --- exla/lib/exla/sharding.ex | 7 ------- 1 file changed, 7 deletions(-) diff --git a/exla/lib/exla/sharding.ex b/exla/lib/exla/sharding.ex index 4293a580e7..281ae530bf 100644 --- 
a/exla/lib/exla/sharding.ex +++ b/exla/lib/exla/sharding.ex @@ -65,11 +65,4 @@ defmodule EXLA.Sharding do def sharding(mesh_name, dim_shardings) do %TensorSharding{mesh_name: to_string(mesh_name), axes: dim_shardings} end - - @doc """ - Creates a fully replicated sharding specification (empty list for all dims). - """ - def replicated(mesh_name, rank) do - %TensorSharding{mesh_name: to_string(mesh_name), axes: List.duplicate([], rank)} - end end
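A worked example (not part of the patch series) of how the mesh-based slicing introduced in patch 3 distributes the inputs used by sharding.exs above. It is an illustrative sketch only: it re-implements the private unravel_index/2 helper and the multi-axis shard-index computation from slice_buffer_for_partition/4 for the mesh [x: 2, y: 2, z: 2] and the first argument of shape {8, 2} sharded as [["x", "z"], ["y"]]; which physical device each partition index ends up on still depends on the PJRT device assignment, which this sketch does not model.

axes = [{"x", 2}, {"y", 2}, {"z", 2}]

# Mirror of unravel_index/2: the last mesh axis varies fastest.
unravel = fn index ->
  {coords, _} =
    Enum.reduce(Enum.reverse(axes), {%{}, index}, fn {name, size}, {acc, idx} ->
      {Map.put(acc, name, rem(idx, size)), div(idx, size)}
    end)

  coords
end

for partition <- 0..7 do
  coords = unravel.(partition)

  # Dim 0 (size 8) is sharded on ["x", "z"]: 4 shards of 2 rows each, with the
  # linear shard index computed as in slice_buffer_for_partition/4
  # (acc * axis_size + coord, folded over the axis names).
  row_shard = Map.fetch!(coords, "x") * 2 + Map.fetch!(coords, "z")

  # Dim 1 (size 2) is sharded on ["y"]: 2 shards of 1 column each.
  col_shard = Map.fetch!(coords, "y")

  IO.puts("partition #{partition} -> rows #{row_shard * 2}..#{row_shard * 2 + 1}, column #{col_shard}")
end

Running this prints, for instance, that partition 0 receives rows 0..1 of column 0 while partition 7 receives rows 6..7 of column 1, matching the chunk_size = div(8, 4) = 2 slices that slice_inputs/2 builds before EXLA.Executable.run/3 is invoked with one argument list per partition.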