Add fractional batch size search for MaxText estimator #2891
📋 Review Summary
This Pull Request introduces fractional batch size search to the MaxText estimator and refactors the internal policy management logic using a RematPolicy class. The changes significantly improve the estimator's capability to find optimal batch sizes, especially for non-batch-based sharding configurations, leading to higher accuracy and better code quality.
🔍 General Feedback
- The introduction of the `RematPolicy` class and `Action` enum is a positive change, enhancing the maintainability and readability of the policy management logic.
- The new `find_pdb_scalar` and the updated `largest_batch_size` correctly implement fractional batch size support, which is a valuable feature.
- The new test suite in `tests/estimator_test.py` provides comprehensive coverage for the new and modified functionalities, ensuring the quality of the changes.
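For readers who have not opened the diff, the enum-plus-class approach praised above can be pictured roughly as follows. This is a minimal sketch, not the code from the PR: `RematPolicySketch`, its `actions` field, `to_dict()`, and the rule for choosing which tensor to move are all assumptions made for illustration. Only the names `Action`, `REMAT`, `OFFLOAD`, `DEVICE`, `next_policy()`, and `previous_policy()` come from the PR description.

```python
from dataclasses import dataclass
from enum import Enum


class Action(Enum):
  REMAT = "remat"      # recompute in the backward pass (least memory)
  OFFLOAD = "offload"  # keep the tensor in host memory
  DEVICE = "device"    # keep the tensor in device memory (least recompute)


# Assumed ordering, from most memory-efficient to most compute-efficient.
_ORDER = [Action.REMAT, Action.OFFLOAD, Action.DEVICE]


@dataclass(frozen=True)
class RematPolicySketch:
  """Hypothetical stand-in for the PR's RematPolicy class."""

  actions: tuple  # tuple of (tensor_name, Action) pairs, in priority order

  def _shift(self, step):
    items = list(self.actions)
    # Toward DEVICE: scan front to back. Toward REMAT: scan back to front.
    indices = range(len(items)) if step > 0 else reversed(range(len(items)))
    for i in indices:
      name, action = items[i]
      j = _ORDER.index(action) + step
      if 0 <= j < len(_ORDER):
        items[i] = (name, _ORDER[j])
        return RematPolicySketch(tuple(items))
    return None  # every tensor is already at the end of the ordering

  def next_policy(self):
    """One step toward the compute-efficient end, or None if exhausted."""
    return self._shift(+1)

  def previous_policy(self):
    """One step toward the memory-efficient end, or None if exhausted."""
    return self._shift(-1)

  def to_dict(self):
    """Plain dict form, e.g. for building command-line overrides."""
    return {name: action.value for name, action in self.actions}
```

Compared with mutating a dictionary in place, an immutable value with explicit transition methods makes each search step easy to test in isolation.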
return {key: "device" for key in tensor_names}
def find_pdb_scalar(config):
  """Calculates the scaling factor to normalize the Per-Device Batch (PDB) size.
🟡 The `largest_batch_size` function is type-hinted to return an `int`, but with the introduction of fractional batch sizes and `pdb_scalar` it now returns `float` values. Please update the type hint to `float`.
def largest_batch_size(base_argv, policy, min_pdb=None, max_pdb=32.0, pdb_scalar=1.0) -> float:
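To illustrate why a `float` return type is the natural fit, here is a rough sketch of a binary search over fractional per-device batch sizes. It is not the PR's implementation: `largest_fitting_pdb` and the `fits` callback are hypothetical, and the sketch assumes that valid batch sizes are multiples of `1 / pdb_scalar` and that `fits` is monotonic (if a batch size fits in memory, every smaller one does too).

```python
from typing import Callable


def largest_fitting_pdb(
    fits: Callable[[float], bool],
    max_pdb: float = 32.0,
    pdb_scalar: float = 1.0,
) -> float:
  """Returns the largest multiple of 1/pdb_scalar (up to max_pdb) that fits.

  Hypothetical helper. `fits(pdb)` stands in for a memory check such as a
  compile-only run; returns 0.0 when no positive batch size fits.
  """
  step = 1.0 / pdb_scalar
  # Search over integer counts of `step` so comparisons stay exact.
  lo, hi = 0, int(max_pdb * pdb_scalar)
  while lo < hi:
    mid = (lo + hi + 1) // 2  # round up so the loop always makes progress
    if fits(mid * step):
      lo = mid
    else:
      hi = mid - 1
  return lo * step
```

With `pdb_scalar=4.0` the candidates are 0.25, 0.5, 0.75 and so on, so a model whose true limit is 2.3 samples per device would yield 2.25 rather than being rounded down to 2.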
Description
This PR introduces fractional batch size search to the MaxText estimator and refactors the internal policy management logic to be more robust and maintainable. These changes allow the estimator to find optimal batch sizes using non-batch-based sharding (e.g., tensor parallelism, context parallelism) that results in fractional per-device batch sizes.
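As a rough illustration of where fractional per-device batch sizes come from: if only the data and FSDP axes shard the batch dimension, each example is shared by all devices along the remaining axes, so the smallest per-device batch is one over the product of those axis sizes. The sketch below assumes exactly that. `pdb_scalar_from_mesh`, the axis names, and the `BATCH_AXES` set are hypothetical and are not taken from `find_pdb_scalar()` in the PR.

```python
import math

# Assumption: only these mesh axes shard the batch dimension.
BATCH_AXES = frozenset({"data", "fsdp"})


def pdb_scalar_from_mesh(axis_sizes: dict) -> float:
  """Product of the sizes of all mesh axes that do not shard the batch."""
  return float(math.prod(
      size for name, size in axis_sizes.items() if name not in BATCH_AXES
  ))


# Hypothetical 512-device mesh: 4 * 16 * 4 * 2.
mesh = {"data": 4, "fsdp": 16, "tensor": 4, "context": 2}
scalar = pdb_scalar_from_mesh(mesh)  # tensor * context
smallest_pdb = 1.0 / scalar          # finest per-device batch granularity
```

In this example the global batch can be as small as 64 examples (one per data/FSDP shard) across 512 devices, which is a per-device batch size of 0.125.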
Key Changes
1. Estimator Refactoring
- Added the `RematPolicy` class and `Action` enum (`REMAT`, `OFFLOAD`, `DEVICE`) to replace dictionary-based policy manipulation.
- Added `next_policy()` and `previous_policy()` methods within `RematPolicy` to handle transitions between memory-efficient and compute-efficient states more cleanly.
2. Fractional Batch Size Support
- Added `find_pdb_scalar()` to calculate a multiplier based on non-data/non-FSDP mesh axes.
- Updated `largest_batch_size` to use the `pdb_scalar` during its search, ensuring compatibility with memory and compute estimation logic while finding fractional boundaries.
3. Configuration & Output
- Added a `write_estimator_result` flag to control whether the estimator writes its findings to a separate file (`remat_commands_from_estimator.txt`).
- Updated `build_argv` to ensure `decoder_layer_input=device` is appended by default if not otherwise specified.
4. Testing
- Added `tests/estimator_test.py`, providing comprehensive coverage for the new `RematPolicy` logic, `pdb_scalar` calculation, and the binary search boundary finding.
Impact
FIXES: b/460511589
Tests
Test 1: search both batch sizes and remat policies (llama3.1-405b on tpu7x-1024)
Command
Output recommendation:
Full output: log
Test 2: search best remat policies given fixed batch size (deepseek3-671b on v5p-1024)
Output recommendation:
Full output: log