Checkpoint utility: gpt-oss, hf to orbax #2818
base: main
Conversation
dae0b48 to dd51cf4
🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces a significant and valuable set of enhancements to the checkpoint conversion utilities. The refactoring improves code modularity by moving shared functions to a central utils.py, and the new functionality, such as support for local Hugging Face models and unscanned GPT-OSS models, greatly increases the flexibility of these tools.
🔍 General Feedback
- Positive: The addition of more granular timing logs is a great improvement for performance analysis and debugging. The updated documentation in the README provides a much clearer overview of supported models and conversion paths.
- Good Refactoring: Moving `check_param_map_keys` to `utils.py` and introducing `get_maxtext_model_info` in `to_maxtext.py` are excellent changes that improve code organization and reusability.
I have left a couple of minor comments, one regarding a bug in the timing calculation and another for a small docstring clarification. Overall, this is a solid contribution.
hengtaoguo left a comment
Thanks for the great work!
    }
    wi_0_1 = input_tensor
    wi_0 = wi_0_1[..., ::2]
    wi_1 = wi_0_1[..., 1::2]
This seems like a smart way to handle interleaving layers!
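For readers new to this layout, here is a tiny NumPy sketch (illustrative only, not the PR's code) of the assumption behind the `::2` / `1::2` slicing: the HF tensor interleaves the two projections along its last axis, so the even columns recover one projection and the odd columns the other.

```python
import numpy as np

# Build a toy tensor whose last axis interleaves two projections
# (columns 0, 2, 4, ... from one, columns 1, 3, 5, ... from the other).
hidden, ffn = 4, 3
wi_0_ref = np.arange(hidden * ffn).reshape(hidden, ffn)  # "even" projection
wi_1_ref = wi_0_ref + 100                                # "odd" projection

interleaved = np.empty((hidden, 2 * ffn), dtype=wi_0_ref.dtype)
interleaved[..., ::2] = wi_0_ref
interleaved[..., 1::2] = wi_1_ref

# The slicing in the diff above undoes the interleaving.
assert np.array_equal(interleaved[..., ::2], wi_0_ref)
assert np.array_equal(interleaved[..., 1::2], wi_1_ref)
```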
    return local_dir


    def check_param_map_keys(param_map_keys, maxtext_state_keys):
Could you also add a quick note on when this function is used and which model it applies to?
I have updated the function name and docstring.
- Overall: preprocesses MaxText keys for transformation in to_maxtext.py and to_huggingface.py.
- Validates param_mapping coverage, so we can catch errors before the transformation.
- Filters unused keys in param_mapping (where some MaxText keys might be grouped), so we can choose the exact MaxText keys to transform (see the sketch below).
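A minimal sketch of those two behaviors, assuming the signature shown in the diff; this is a hypothetical illustration, not the PR's actual implementation.

```python
def check_param_map_keys(param_map_keys, maxtext_state_keys):
  """Hypothetical sketch: validate coverage, then filter to the keys actually needed."""
  # param_map_keys may contain single keys or tuples of grouped MaxText keys.
  mapped = set()
  for key in param_map_keys:
    mapped.update(key if isinstance(key, tuple) else (key,))

  # 1) Coverage: catch mapping gaps before any tensor is transformed.
  missing = set(maxtext_state_keys) - mapped
  if missing:
    raise ValueError(f"param_mapping is missing MaxText keys: {sorted(missing)}")

  # 2) Filter: keep only the mapping keys present in this checkpoint's state.
  state = set(maxtext_state_keys)
  return [key for key in param_map_keys
          if (set(key) & state if isinstance(key, tuple) else key in state)]
```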
    mt_target_shape_final = abstract_leaf_value.shape
    if not use_lazy_load and config.scan_layers:
      max_logging.log(f"maxtext param: {mt_param_key}")
    is_many_mt_key = isinstance(mt_param_key, tuple)
Does it mean the conversions are later categorized into two main types, 1-to-1 mapping (the most common case) and 1-to-N mapping, based on this flag?
A very naive question, and I don't have any thoughts yet: do you see any possibility of N-to-N mapping?
To be specific, we have 8 combinations: maxtext_key in [single, many - tuple] x hf_key in [single - unscan, list - scan, list - unscan expert, nested list - scan expert]
1 maxtext_key_single x hf_key_single-unscan: {m: h}
2 maxtext_key_single x hf_key_list-scan: {m: [h_l1, h_l2, ..., h_ln]}
3 maxtext_key_single x hf_key_list-unscan_expert: {m: [e1, e2, ..., ek]}
4 maxtext_key_single x hf_key_nested_list-scan_expert: {m: [[e1_l1, ..., e1_ln], ..., [ek_l1, ..., ek_ln]]}
5 maxtext_key_many x hf_key_single-unscan: {(m1, ..., mp): h}
6 maxtext_key_many x hf_key_list-scan: {(m1, ..., mp): [h_l1, h_l2, ..., h_ln]}
7 maxtext_key_many x hf_key_list-unscan_expert: {(m1, ..., mp): [e1, e2, ..., ek]}
8 maxtext_key_many x hf_key_nested_list-scan_expert: {(m1, ..., mp): [[e1_l1, ..., e1_ln], ..., [ek_l1, ..., ek_ln]]}
In summary:
- 1 & 2: most common
- 3 & 4: present for some MoE models. For instance, in qwen3-moe, deepseek, and mixtral, the N experts are stored as N hf tensors. (On the other hand, some models do not need this; gpt-oss stores the N experts as a single combined hf tensor.)
- 5 & 6: present for some models where one hf key is split into several equal-shape maxtext keys (e.g., gpt-oss, and llama4, which hasn't been onboarded)
- 7 & 8: I don't expect this to happen, and I haven't tested it. (It could probably work?)
Unfortunately, this has become quite convoluted (a small sketch of how the cases can be told apart follows below). Probably I should add more comments?
cc @RissyRan
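To make the taxonomy above concrete, here is a small hedged sketch (illustrative names and keys, not the PR's code) of how one `param_mapping` entry could be classified, consistent with the `isinstance` checks visible in the diff.

```python
def classify_entry(mt_param_key, hf_source_keys_or_key):
  """Return (is_many_mt_key, hf_kind) for one param_mapping entry."""
  is_many_mt_key = isinstance(mt_param_key, tuple)       # cases 5-8 vs. 1-4
  if isinstance(hf_source_keys_or_key, str):
    hf_kind = "single (unscan)"                          # cases 1, 5
  elif isinstance(hf_source_keys_or_key[0], list):
    hf_kind = "nested list (scan expert)"                # cases 4, 8
  else:
    hf_kind = "flat list (scan, or unscan expert)"       # cases 2, 3, 6, 7
  return is_many_mt_key, hf_kind

# Example with made-up keys: a tuple of MaxText keys mapped to per-layer HF keys (case 6).
print(classify_entry(("wi_0", "wi_1"), ["layers.0.gate_up", "layers.1.gate_up"]))
```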
Thanks for the explanation! Let's create a bug to track the refactor. This will be confusing; let's refactor each case into a small helper function.
If you could wrap your added code into a helper function, it would be a good start!
ff78257 to c5f0176
    - Gemma3 multimodal (4B, 12B, 27B).
    - Qwen3 (0.6B, 4B, 8B, 14B, 32B).
    - Mixtral (8x7B, 8x22B).
    | Model Family | Sizes | HF $\to$ Orbax (scan) | HF $\to$ Orbax (unscan) | Orbax (scan) $\to$ HF | Orbax (unscan) $\to$ HF |
Thanks for adding this table!
    def _initialize_index(self):
      """Fetches and parses the Hugging Face model index file to build a shard map."""
      files = list_repo_files(self.model_id, token=self.token)
      if self.is_local:
Can we add this option in the README.md?
I have added `--hf_model_path` to the README.
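For context on the `is_local` branch above, here is a minimal sketch of how file listing could differ between a local directory (e.g. passed via `--hf_model_path`) and a Hub repo id. The helper name is hypothetical; `list_repo_files` is the `huggingface_hub` call already used in the diff.

```python
import os
from huggingface_hub import list_repo_files

def _list_model_files(model_id_or_path, token=None):
  """Hypothetical helper: list checkpoint files for a local directory or a Hub repo id."""
  if os.path.isdir(model_id_or_path):
    # Local checkout: read the directory directly, no network or token needed.
    return sorted(os.listdir(model_id_or_path))
  # Remote repo id: ask the Hub for its file index.
  return list_repo_files(model_id_or_path, token=token)
```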
    load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_final, hook_fn)
    load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape, hook_fn)
    # Stacked mapping
    elif isinstance(hf_source_keys_or_key[0], list):
I was really confused when onboarding Mixtral models. Do you think we could add a decision tree or something similar at the top?
I have added more comments for these conditions.
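As one concrete illustration of the "Stacked mapping" branch above (a hedged sketch, not the PR's loader): when the mapping value is a list of per-layer HF keys, the loaded tensors can be stacked along a new layer axis to form one scanned MaxText array, and a nested list would add an expert dimension the same way. The exact axis placement depends on MaxText's scan configuration, so treat the axes here as illustrative.

```python
import numpy as np

# Per-layer HF tensors for one scanned MaxText parameter (5 layers, each of shape (4, 3)).
per_layer = [np.full((4, 3), layer) for layer in range(5)]
scanned = np.stack(per_layer, axis=0)            # -> (num_layers, 4, 3)
assert scanned.shape == (5, 4, 3)

# Nested list (scan-expert case): stack layers per expert, then stack experts.
per_expert = [[np.zeros((4, 3)) for _ in range(5)] for _ in range(2)]   # 2 experts x 5 layers
scanned_experts = np.stack([np.stack(layers, axis=0) for layers in per_expert], axis=0)
assert scanned_experts.shape == (2, 5, 4, 3)
```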
RissyRan left a comment
Thanks! If you could wrap your added code into helper functions for those 8 cases, it would be really appreciated :) Also, let's create a bug to track the rest of the refactor.
Description
Previously we had gpt-oss, orbax (scan) -> hf: #2647
Fix: b/459541579
Fix: b/452392132
Fix: b/452391921
What this does
- to_maxtext.py
- param_mapping.py, gpt-oss
- to_huggingface.py
- _check_param_map_keys to utils.py, so it can be reused by to_maxtext.py
Tests
1 HF -> orbax (gpt-oss-20b)
Since we made non-trivial changes to the lazy tensor implementation, also test lazy mode.
HF -> orbax (scan), cpu
https://paste.googleplex.com/4888544332087296
3.56 min
https://paste.googleplex.com/5272274628378624
HF -> orbax (scan), cpu, lazy load
https://paste.googleplex.com/6192468888518656
4.96 min
https://paste.googleplex.com/5022180226236416
HF -> orbax (unscan), cpu
https://paste.googleplex.com/6000687559344128
https://paste.googleplex.com/6128993298415616
2 orbax -> HF (gpt-oss-20b)
orbax -> HF (unscan), cpu
https://paste.googleplex.com/5483624130543616
https://paste.googleplex.com/4854884102963200
3 HF -> orbax (check other models just in case)
qwen3-4b
https://paste.googleplex.com/5401590255190016
https://paste.googleplex.com/6538164284030976
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.