diff --git a/docs/reference/core_concepts/moe_configuration.md b/docs/reference/core_concepts/moe_configuration.md
index 668cb1aa4..65f772009 100644
--- a/docs/reference/core_concepts/moe_configuration.md
+++ b/docs/reference/core_concepts/moe_configuration.md
@@ -50,7 +50,7 @@ Dropping:
 
 `first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.
 
-`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability.
+`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability. Recommended when lower-precision types cause convergence or quality issues.
 
 ### Routing Mechanism
 `use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.
@@ -82,11 +82,11 @@ Dropping:
 * Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
 * Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.
 
-`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul.
+`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul. Recommended, as it replaces the inefficient scatter-add generated by `jax.numpy.take` in the backward pass.
 
-`mlp_bias`: If enabled, add bias terms within the expert MLP layers.
+`mlp_bias`: If enabled, add learnable bias terms to the MLP matmuls. Originally implemented to support the GPT-OSS model architecture.
 
-`use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications.
+`use_batch_split_schedule` (experimental): If enabled, split the batch into micro-batches to hide communication behind compute, which can yield performance benefits.
 
 ## 2. Sharding
 `expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
@@ -95,9 +95,9 @@ Dropping:
 
 `use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication.
 
-`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
+`moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
 
-`fsdp_shard_on_exp`: If enabled, shard MLP weights on expert dimension instead of embedding dimension during FSDP sharding.
+`fsdp_shard_on_exp`: If enabled, shard the expert dimension of the MLP weights on the FSDP axis. Recommended when num_experts is a multiple of fsdp_parallelism.
 
 ## 3. Performance Tuning
 These parameters provide granular control over the tiling dimensions for sparse matmul Pallas kernel.
diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 942aa24ba..079e1aaed 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -157,7 +157,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
-float32_weight_sum: true # whether to use full fp32 precision for weight_sum during final unpermute in moe
+float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability
 
 # multi-token prediction configs
 # the number of auxiliary prediction layers to use for mtp.
@@ -179,7 +179,7 @@ sparse_matmul: true
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
 load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
-use_custom_sort_vjp: true # whether to use a custom sort vjp for sparse matmul ops
+use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
 use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
 # tunable tiling dimensions used for mlp gmm
 # megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
@@ -212,7 +212,8 @@ expert_shard_attention_option: "fsdp"
 
 # when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
 moe_fsdp_use_two_stage_all_gather: false
-# shard the moe weights on num_expert_dim. this can be performanct when num_expert % fdsp_parallisum
+# Shard the expert dimension of the MLP weights on the FSDP axis.
+# This configuration is recommended only when num_experts is a multiple of fsdp_parallelism.
 fsdp_shard_on_exp: False
 # use fsdp and fsdp_transpose axes for sharding the moe weights
 use_2d_fsdp_sharding: False
@@ -225,13 +226,12 @@ routed_scaling_factor: 1.0 # scaling factor for routing scores
 routed_score_func: "" # scoring function for routing
 routed_bias: False # a flag if a learnable bias is added for routing
 routed_bias_update_rate: 0.0 # a flag indicate the update rate applied to the router bias term
-mlp_bias: False # a flag if a learnable bias is added for MLP matmul
+mlp_bias: False # a flag if a learnable bias is added for MLP matmul; originally implemented to support the GPT-OSS model architecture.
 n_routing_groups: -1 # number of groups for routing, disabled by default
 topk_routing_group: -1 # number of top groups to route inputs. For EP,
 # Splits the batch to allow for better scheduling when using expert parallelism by overlapping the
 # all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
-use_batch_split_schedule: False # whether to use batch split schedule
-# sending activations to a maximum of topk_routing_group distinct devices can yield performance benefits.
+use_batch_split_schedule: False # a flag if the batch is split into micro-batches to hide communication, which can yield performance benefits.
 
 # For complex architectures like llama4 there are repeated sets of
 # inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index f0280c31d..01b478cf0 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -553,7 +553,9 @@ class MoEGeneral(BaseModel):
   num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
   load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
-  use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.")
+  use_custom_sort_vjp: bool = Field(
+      True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
+  )
   use_ring_of_experts: bool = Field(
       False,
       description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
@@ -570,8 +572,8 @@ class MoEGeneral(BaseModel):
   )
   fsdp_shard_on_exp: bool = Field(
       False,
-      description="Shard the MoE weights on the num_expert dimension. Can be performant when "
-      "num_experts % fsdp_parallelism != 0.",
+      description="Shard the expert dimension of the MLP weights on the FSDP axis. "
+      "Recommended when num_experts is a multiple of fsdp_parallelism.",
   )
   use_2d_fsdp_sharding: bool = Field(
       False,
@@ -583,7 +585,7 @@ class MoEGeneral(BaseModel):
   )
   float32_weight_sum: bool = Field(
       True,
-      description="Whether to use full fp32 precision for weight_sum during final unpermute in MoE.",
+      description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
   )
 
 
@@ -640,13 +642,16 @@ class DeepSeekMoE(BaseModel):
   routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
   routed_bias: bool = Field(False, description="Whether to add a bias term for routing.")
   routed_bias_update_rate: float = Field(0.0, description="Update rate applied to the router bias term.")
-  mlp_bias: bool = Field(False, description="Whether to add a learnable bias for MLP matmul.")
+  mlp_bias: bool = Field(
+      False,
+      description="Whether to add a learnable bias for MLP matmul; "
+      "originally implemented to support the GPT-OSS model architecture.",
+  )
   n_routing_groups: int = Field(-1, description="Number of groups for routing, disabled by default.")
   topk_routing_group: int = Field(-1, description="Number of top groups to route inputs to.")
   use_batch_split_schedule: bool = Field(
       False,
-      description="Splits the batch to allow for better scheduling when using expert parallelism by overlapping all-to-all "
-      "with compute.",
+      description="Whether to split the batch into micro-batches to hide communication, which can yield performance benefits.",
   )
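
Taken together, the MoE flags documented and described above are ordinary base.yml keys. A minimal sketch of how they might appear in an override config follows; every key is taken from base.yml as shown in this diff, but the values are illustrative placeholders rather than recommendations from this change.

# Illustrative MoE overrides; keys come from base.yml, values are placeholders.
sparse_matmul: true              # use the sparse matmul path
capacity_factor: -1.0            # dropless by default
use_custom_sort_vjp: true        # custom VJP sort for the sparse matmul backward pass
float32_weight_sum: true         # sum expert weights in fp32 for numerical stability
fsdp_shard_on_exp: false         # enable only when num_experts is a multiple of fsdp_parallelism
mlp_bias: false                  # learnable MLP bias (e.g., for GPT-OSS)
use_batch_split_schedule: false  # experimental micro-batch split to hide communication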