Conversation

@NuojCheng (Collaborator) commented Dec 27, 2025

Description

This PR introduces fractional batch size search to the MaxText estimator and refactors the internal policy management logic to be more robust and maintainable. These changes let the estimator find the optimal batch size for configurations that use non-batch sharding (e.g., tensor parallelism, context parallelism), where the per-device batch size can be fractional.

Key Changes

1. Estimator Refactoring

  • Structured Policy Management: Introduced the RematPolicy class and Action enum (REMAT, OFFLOAD, DEVICE) to replace dictionary-based policy manipulation.
  • Policy Transitions: Implemented next_policy() and previous_policy() methods within RematPolicy to handle transitions between memory-efficient and compute-efficient states more cleanly.
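The transition logic described above might look roughly like the following. This is an illustrative reconstruction, not the actual MaxText code: the RematPolicy class, Action enum, and next_policy()/previous_policy() names come from the PR description, but the action ordering, fields, and method signatures are assumptions.

```python
# Hypothetical sketch of the policy representation; the real class in
# src/MaxText/estimator.py may differ in fields and method signatures.
from enum import Enum


class Action(Enum):
  """Where a tensor's activations live between the forward and backward pass."""
  REMAT = "remat"      # recompute in the backward pass (least memory)
  OFFLOAD = "offload"  # move to host memory
  DEVICE = "device"    # keep on device (most memory, least recompute)


# Ordered from most memory-efficient to most compute-efficient.
_ORDER = [Action.REMAT, Action.OFFLOAD, Action.DEVICE]


class RematPolicy:
  """Maps tensor names to Actions and steps between adjacent policies."""

  def __init__(self, actions):
    self.actions = dict(actions)

  def next_policy(self, tensor_name):
    """Copy with `tensor_name` moved one step toward compute-efficient, or None."""
    i = _ORDER.index(self.actions[tensor_name])
    if i + 1 >= len(_ORDER):
      return None  # already at the most compute-efficient action
    new_actions = dict(self.actions)
    new_actions[tensor_name] = _ORDER[i + 1]
    return RematPolicy(new_actions)

  def previous_policy(self, tensor_name):
    """Copy with `tensor_name` moved one step toward memory-efficient, or None."""
    i = _ORDER.index(self.actions[tensor_name])
    if i == 0:
      return None  # already at the most memory-efficient action
    new_actions = dict(self.actions)
    new_actions[tensor_name] = _ORDER[i - 1]
    return RematPolicy(new_actions)
```

Compared with raw dictionary manipulation, an explicit ordering makes "one step more memory-efficient" a single well-defined operation the search can rely on.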

2. Fractional Batch Size Support

  • Normalization Constant: Added find_pdb_scalar() to calculate a multiplier based on non-data/non-FSDP mesh axes.
  • Updated Binary Search: Modified largest_batch_size to use the pdb_scalar during its search, ensuring compatibility with memory and compute estimation logic while finding fractional boundaries.
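A minimal sketch of the idea behind these two pieces; the real functions in src/MaxText/estimator.py take a MaxText config and run compile-based memory estimation, so the plain mesh dict and the `fits` predicate here are stand-ins:

```python
# Illustrative only: mesh axes are a plain dict and `fits` is a stand-in
# for the estimator's memory-feasibility check.
def find_pdb_scalar(mesh_axes):
  """Product of parallelism degrees on non-data/non-FSDP mesh axes.

  With scalar s, any per-device batch of the form k/s (integer k) yields an
  integral global batch per shard, so the search can walk integers k while
  the reported per-device batch size is fractional.
  """
  batch_axes = {"data", "fsdp"}
  scalar = 1
  for axis, degree in mesh_axes.items():
    if axis not in batch_axes:
      scalar *= degree
  return scalar


def largest_batch_size(fits, pdb_scalar=1.0, max_pdb=32.0) -> float:
  """Binary-search the largest pdb, in steps of 1/pdb_scalar, where fits(pdb) holds."""
  lo, hi = 0, int(max_pdb * pdb_scalar)  # search integer multiples of 1/pdb_scalar
  while lo < hi:
    mid = (lo + hi + 1) // 2
    if fits(mid / pdb_scalar):
      lo = mid  # feasible: try larger
    else:
      hi = mid - 1  # infeasible: shrink
  return lo / pdb_scalar
```

For example, with ici_context_parallelism=8 the scalar is 8, so the search can land on fractional boundaries such as pdb=0.25 or pdb=0.125, matching the Test 1 output below.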

3. Configuration & Output

  • Configurable Output: Added the write_estimator_result flag to control whether the estimator writes its findings to a separate file (remat_commands_from_estimator.txt).
  • Default Behavior: Updated build_argv to ensure decoder_layer_input=device is appended by default if not otherwise specified.
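The default-appending behavior might look roughly like this; build_argv's actual signature and argument handling in the estimator are assumptions here:

```python
# Hypothetical illustration of the default described above; the real
# build_argv in the estimator may take different parameters.
def build_argv(base_argv, overrides):
  argv = list(base_argv) + list(overrides)
  # Append decoder_layer_input=device unless the caller already set it.
  if not any(a.startswith("decoder_layer_input=") for a in argv):
    argv.append("decoder_layer_input=device")
  return argv
```

An explicit override (as in Test 2, which passes decoder_layer_input=offload) would suppress the default.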

4. Testing

  • New Test Suite: Added tests/estimator_test.py providing comprehensive coverage for the new RematPolicy logic, pdb_scalar calculation, and the binary search boundary finding.

Impact

  • Higher Accuracy: The estimator can now find the absolute maximum batch size for complex topologies where per-device batch sizes are not integers.
  • Better Code Quality: Moving to a class-based policy representation makes the search algorithm easier to reason about and extend.

FIXES: b/460511589

Tests

Test 1: search over both batch size and remat policies (llama3.1-405b on tpu7x-1024)

Command

python -m MaxText.estimator \
MaxText/configs/base.yml \
compile_topology=tpu7x-1024 \
compile_topology_num_slices=1 \
model_name=llama3.1-405b \
max_target_length=32768 \
ici_context_parallelism=8 \
ici_fsdp_parallelism=-1 \
log_config=False \
write_estimator_result=False

Output recommendation:

Search completed in 647.53 seconds.
  - Found valid combo: pdb=0.25, policy={'mlpwo': 'remat', 'context': 'remat', 'mlpwi_0': 'remat', 'mlpwi_1': 'remat', 'query_proj': 'remat', 'out_proj': 'remat', 'key_proj': 'remat', 'value_proj': 'remat'}
  - Found valid combo: pdb=0.25, policy={'mlpwo': 'remat', 'context': 'remat', 'mlpwi_0': 'remat', 'mlpwi_1': 'remat', 'query_proj': 'remat', 'out_proj': 'remat', 'key_proj': 'remat', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.25, policy={'mlpwo': 'remat', 'context': 'remat', 'mlpwi_0': 'remat', 'mlpwi_1': 'remat', 'query_proj': 'remat', 'out_proj': 'remat', 'key_proj': 'offload', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.25, policy={'mlpwo': 'remat', 'context': 'remat', 'mlpwi_0': 'remat', 'mlpwi_1': 'remat', 'query_proj': 'remat', 'out_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.25, policy={'mlpwo': 'remat', 'context': 'remat', 'mlpwi_0': 'remat', 'mlpwi_1': 'remat', 'query_proj': 'offload', 'out_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.125, policy={'mlpwo': 'remat', 'context': 'remat', 'mlpwi_0': 'remat', 'mlpwi_1': 'offload', 'query_proj': 'offload', 'out_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.125, policy={'mlpwo': 'remat', 'context': 'remat', 'mlpwi_0': 'offload', 'mlpwi_1': 'offload', 'query_proj': 'offload', 'out_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.125, policy={'mlpwo': 'remat', 'context': 'offload', 'mlpwi_0': 'offload', 'mlpwi_1': 'offload', 'query_proj': 'offload', 'out_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.125, policy={'mlpwo': 'offload', 'context': 'offload', 'mlpwi_0': 'offload', 'mlpwi_1': 'offload', 'query_proj': 'offload', 'out_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'offload'}
  - Found valid combo: pdb=0.125, policy={'mlpwo': 'offload', 'context': 'offload', 'mlpwi_0': 'offload', 'mlpwi_1': 'offload', 'query_proj': 'offload', 'out_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'device'}
  - Found valid combo: pdb=0.125, policy={'mlpwo': 'offload', 'context': 'offload', 'mlpwi_0': 'offload', 'mlpwi_1': 'offload', 'query_proj': 'offload', 'out_proj': 'offload', 'key_proj': 'device', 'value_proj': 'device'}
  - Found valid combo: pdb=0.125, policy={'mlpwo': 'offload', 'context': 'offload', 'mlpwi_0': 'offload', 'mlpwi_1': 'offload', 'query_proj': 'offload', 'out_proj': 'device', 'key_proj': 'device', 'value_proj': 'device'}
Done.

Full output: log

Test 2: search best remat policies given fixed batch size (deepseek3-671b on v5p-1024)

python3 -m MaxText.estimator MaxText/configs/base.yml \
model_name=deepseek3-671b \
compile_topology=v5p-1024 \
compile_topology_num_slices=1 \
ici_fsdp_parallelism=512 \
per_device_batch_size=2.0 \
dtype=bfloat16 \
weight_dtype=float32 \
scan_layers=True \
sparse_matmul=True \
use_custom_sort_vjp=False \
use_tokamax_splash=True \
use_tokamax_gmm=True \
sa_use_fused_bwd_kernel=False \
max_target_length=8192 \
sa_block_q=1024 \
sa_block_kv=1024 \
sa_block_kv_compute=1024 \
sa_block_q_dkv=1024 \
sa_block_kv_dkv=1024 \
sa_block_kv_dkv_compute=1024 \
sa_block_q_dq=1024 \
sa_block_kv_dq=1024 \
log_config=False \
write_estimator_result=False \
decoder_layer_input=offload

Output recommendation:

Search completed in 468.94 seconds.
  - Found valid combo: pdb=2.0, policy={'mlpwo': 'remat', 'out_proj': 'remat', 'context': 'remat', 'mlpwi_0': 'remat', 'mlpwi_1': 'remat', 'query_proj': 'offload', 'key_proj': 'offload', 'value_proj': 'offload'}
Done.

Full output: log

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng added the draft (Draft PR) label Dec 27, 2025

codecov bot commented Dec 27, 2025

Codecov Report

❌ Patch coverage is 0% with 85 lines in your changes missing coverage. Please review.

Files with missing lines  | Patch % | Lines
src/MaxText/estimator.py  | 0.00%   | 85 Missing ⚠️


@NuojCheng force-pushed the chengnuojin-estimator branch from 45a8a31 to afb6a3c on December 30, 2025 at 23:30

@github-actions github-actions bot left a comment


📋 Review Summary

This Pull Request introduces fractional batch size search to the MaxText estimator and refactors the internal policy management logic using a RematPolicy class. The changes significantly improve the estimator's capability to find optimal batch sizes, especially for non-batch-based sharding configurations, leading to higher accuracy and better code quality.

🔍 General Feedback

  • The introduction of the RematPolicy class and Action enum is a positive change, enhancing the maintainability and readability of the policy management logic.
  • The new find_pdb_scalar and the updated largest_batch_size correctly implement fractional batch size support, which is a valuable feature.
  • The new test suite in tests/estimator_test.py provides comprehensive coverage for the new and modified functionalities, ensuring the quality of the changes.

src/MaxText/estimator.py (review context):

    return {key: "device" for key in tensor_names}
    def find_pdb_scalar(config):
    """Calculates the scaling factor to normalize the Per-Device Batch (PDB) size.


🟡 The largest_batch_size function is type-hinted to return an int, but with the introduction of fractional batch sizes and pdb_scalar, it now returns float values. Please update the type hint to float.

Suggested change:

    def largest_batch_size(base_argv, policy, min_pdb=None, max_pdb=32.0, pdb_scalar=1.0) -> float:

