Conversation

@shuningjin (Collaborator) commented Dec 11, 2025

Description

Previously, we added gpt-oss, orbax(scan) -> hf: #2647

Fix: b/459541579

  • gpt-oss, hf -> orbax(scan)
  • gpt-oss, hf -> orbax(unscan)
  • gpt-oss, orbax(unscan) -> hf

Fix: b/452392132

  • implement weight splitting for hf->orbax (i.e., one hf key to many maxtext keys)

Fix: b/452391921

  • verify interleaved scan pattern for hf->orbax

What this does

to_maxtext.py

  • allow mapping one hf key to many mt keys (see the sketch after this list)
    • assume the mt keys have the same shape; the hook function returns a tensor with the per-key slices stacked in the last dim
    • accommodate lazy tensors: unoptimized, the hf tensor is repeatedly loaded for each mt key
  • allow loading a local hf checkpoint
    • the remote hf checkpoint is quantized for some models (e.g., gpt-oss - mxfp4, deepseek - fp8), so we use a local de-quantized hf version (usually bf16) for conversion
    • accommodate lazy tensors
  • improve the condition for single-axis stacking
  • (other: factor out get_maxtext_dict, add timing)
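
As a rough illustration of the hf-to-many-mt splitting above (hypothetical names; the actual hook and loader in to_maxtext.py may differ in detail), a minimal numpy sketch:

import numpy as np

def example_split_hook(hf_tensor):
  # Hypothetical hook: split the hf tensor into two equal-shape pieces and
  # stack them in a new last dim, one slice per target mt key.
  half = hf_tensor.shape[-1] // 2
  return np.stack([hf_tensor[..., :half], hf_tensor[..., half:]], axis=-1)

def assign_to_mt_keys(hf_tensor, mt_keys, hook_fn):
  # Each mt key takes its slice of the stacked last dim.
  stacked = hook_fn(hf_tensor)
  return {key: stacked[..., i] for i, key in enumerate(mt_keys)}

fused = np.arange(12, dtype=np.float32).reshape(3, 4)
weights = assign_to_mt_keys(fused, ["mlp.wi_0", "mlp.wi_1"], example_split_hook)
assert weights["mlp.wi_0"].shape == weights["mlp.wi_1"].shape == (3, 2)

# With lazy tensors, the hf tensor is re-materialized once per mt key
# (unoptimized, as noted above), roughly:
#   for i, key in enumerate(mt_keys):
#     weights[key] = hook_fn(load_hf_tensor(hf_key))[..., i]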

param_mapping.py, gpt-oss

  • implement an interleave function for the hf-to-many-mt split (a toy sketch follows this list)
  • add an unscanned version of the mapping and hook
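
A minimal sketch of the interleave handling, assuming (as in the diff hunk quoted in the review below) that the fused HF tensor stores the two targets column-interleaved:

import numpy as np

def split_interleaved(wi_0_1):
  # De-interleave the fused tensor: even columns -> wi_0, odd columns -> wi_1.
  wi_0 = wi_0_1[..., ::2]
  wi_1 = wi_0_1[..., 1::2]
  return wi_0, wi_1

fused = np.arange(24).reshape(2, 12)  # toy tensor
wi_0, wi_1 = split_interleaved(fused)
assert wi_0.shape == wi_1.shape == (2, 6)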

to_huggingface.py

  • move _check_param_map_keys to utils.py, so it can be reused by to_maxtext.py

Tests

1 HF -> orbax (gpt-oss-20b)

Since we made non-trivial changes to the lazy tensor implementation, we also test lazy mode.

HF -> orbax (scan), cpu

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2

https://paste.googleplex.com/4888544332087296

3.56 min

CKPT=gs://runner-maxtext-logs/2025-12-26-21-51
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/5272274628378624

HF -> orbax (scan), cpu, lazy load

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
--lazy_load_tensors=true

https://paste.googleplex.com/6192468888518656

4.96 min

CKPT=gs://runner-maxtext-logs/2025-12-26-21-58
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/5022180226236416

HF -> orbax (unscan), cpu

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2

https://paste.googleplex.com/6000687559344128

CKPT=gs://runner-maxtext-logs/2025-12-26-22-20
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/6128993298415616

2 orbax -> HF (gpt-oss-20b)

orbax -> HF (unscan), cpu

ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/unscan-bf16-v2-2025-09-02-01-16-00/0/items \
base_output_directory=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/5483624130543616

HF_PATH=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-2025-12-26-22-56-40
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/unscan-bf16-v2-2025-09-02-01-16-00/0/items \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/4854884102963200

3 HF -> orbax (check other models just in case)

qwen3-4b

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-4b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True

https://paste.googleplex.com/5401590255190016

CKPT=gs://runner-maxtext-logs/2025-12-26-23-25
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-4b attention=dot_product \
override_model_config=true enable_dropout=false tokenizer_type=huggingface \
load_parameters_path=$CKPT/0/items scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=8 \
tokenizer_path=Qwen/Qwen3-4B --run_hf_model=True --hf_model_path=Qwen/Qwen3-4B \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
skip_jax_distributed_system=True

https://paste.googleplex.com/6538164284030976

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@shuningjin shuningjin changed the title [WIP] checkpoint util: gpt-oss, orbax -> hf [WIP] checkpoint util: gpt-oss, hf -> orbax Dec 11, 2025
codecov bot commented Dec 23, 2025

@shuningjin shuningjin changed the title [WIP] checkpoint util: gpt-oss, hf -> orbax Checkpoint utility: gpt-oss, hf to orbax Dec 26, 2025
@shuningjin shuningjin marked this pull request as ready for review December 26, 2025 23:45

@github-actions bot left a comment


📋 Review Summary

This pull request introduces a significant and valuable set of enhancements to the checkpoint conversion utilities. The refactoring improves code modularity by moving shared functions to a central utils.py, and the new functionality, such as support for local Hugging Face models and unscanned GPT-OSS models, greatly increases the flexibility of these tools.

🔍 General Feedback

  • Positive: The addition of more granular timing logs is a great improvement for performance analysis and debugging. The updated documentation in the README provides a much clearer overview of supported models and conversion paths.
  • Good Refactoring: Moving check_param_map_keys to utils.py and introducing get_maxtext_model_info in to_maxtext.py are excellent changes that improve code organization and reusability.

I have left a couple of minor comments, one regarding a bug in the timing calculation and another for a small docstring clarification. Overall, this is a solid contribution.

@hengtaoguo (Collaborator) left a comment

Thanks for the great work!

}
wi_0_1 = input_tensor
wi_0 = wi_0_1[..., ::2]
wi_1 = wi_0_1[..., 1::2]
Collaborator comment:

This seems like a smart way to handle interleaving layers!

return local_dir


def check_param_map_keys(param_map_keys, maxtext_state_keys):
Collaborator comment:

Could you also add a quick note on when this function is used and which model it applies to?

Author reply:

I have updated the function name and docstring.

  • Overall: preprocess maxtext keys for transformation in to_maxtext.py and to_huggingface.py
  • Validates param_mapping coverage, so we can catch errors before transformation
  • Filters unused keys in param_mapping (where some mt keys might be grouped), so we can choose the exact mt keys to transform; a rough sketch of this behavior follows.
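
A rough sketch of that behavior (not the actual implementation; the signature and error handling are illustrative only):

def check_param_map_keys(param_map_keys, maxtext_state_keys):
  # Flatten tuple-grouped mt keys and verify the mapping covers the whole state.
  covered = set()
  for key in param_map_keys:
    covered.update(key if isinstance(key, tuple) else (key,))
  missing = set(maxtext_state_keys) - covered
  if missing:
    raise ValueError(f"param_mapping is missing maxtext keys: {sorted(missing)}")

  # Keep only the mapping keys whose mt keys actually appear in the target state.
  def used(key):
    members = key if isinstance(key, tuple) else (key,)
    return any(k in maxtext_state_keys for k in members)

  return [key for key in param_map_keys if used(key)]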

mt_target_shape_final = abstract_leaf_value.shape
if not use_lazy_load and config.scan_layers:
max_logging.log(f"maxtext param: {mt_param_key}")
is_many_mt_key = isinstance(mt_param_key, tuple)
Collaborator comment:

Does it mean that, later on, the conversions are categorized into two main types based on this flag: 1-to-1 mapping (the most common case) and 1-to-N mapping?

A very naive question, and I don't have any thoughts yet: do you see a possibility of N-to-N mapping?

Author reply:

To be specific, we have 8 combinations: maxtext_key in [single, many - tuple] x hf_key in [single - unscan, list - scan, list - unscan expert, nested list - scan expert]. A minimal sketch of a few of these mapping shapes follows the summary below.

1 maxtext_key_single x hf_key_single-unscan: {m: h}
2 maxtext_key_single x hf_key_list-scan: {m: [h_l1, h_l2, ..., h_ln]}
3 maxtext_key_single x hf_key_list-unscan_expert: {m: [e1, e2, ..., ek]}
4 maxtext_key_single x hf_key_nested_list-scan_expert: {m: [[e1_l1, ..., e1_ln], ..., [ek_l1, ..., ek_ln]]}
5 maxtext_key_many x hf_key_single-unscan: {(m1, ..., mp): h}
6 maxtext_key_many x hf_key_list-scan: {(m1, ..., mp): [h_l1, h_l2, ..., h_ln]}
7 maxtext_key_many x hf_key_list-unscan_expert: {(m1, ..., mp): [e1, e2, ..., ek]}
8 maxtext_key_many x hf_key_nested_list-scan_expert: {(m1, ..., mp): [[e1_l1, ..., e1_ln], ..., [ek_l1, ..., ek_ln]]}

In summary:

  • 1 & 2: most common
  • 3 & 4: present for some MoE models. For instance, in qwen3-moe, deepseek, and mixtral, the N experts are stored as N hf tensors. (On the other hand, some models do not need this; gpt-oss stores the N experts as a single combined hf tensor.)
  • 5 & 6: present for some models where an hf key is split into several equal-shape maxtext keys (e.g., gpt-oss, and llama4, which hasn't been onboarded)
  • 7 & 8: I don't expect these to happen. I haven't tested them. (They could probably work?)
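
Purely illustrative mapping shapes for a few of these cases (hypothetical key names, not the real gpt-oss mapping):

PARAM_MAP_EXAMPLE = {
    # case 1: single mt key <- single unscanned hf key
    "decoder/norm/scale": "model.norm.weight",
    # case 2: single mt key <- list of per-layer hf keys (scan)
    "decoder/layers/pre_attn_norm/scale": [
        f"model.layers.{l}.input_layernorm.weight" for l in range(4)
    ],
    # case 5: many mt keys <- one fused unscanned hf key, split via a hook
    ("decoder/layers_0/mlp/wi_0", "decoder/layers_0/mlp/wi_1"):
        "model.layers.0.mlp.experts.gate_up_proj",
}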

Unfortunately, this has become quite convoluted. Perhaps I should add more comments?

cc @RissyRan

Collaborator reply:

Thanks for the explanation! Let's create a bug to track the refactor, since this will be confusing. Let's refactor each case into a small helper function.

If you could wrap your added code into a helper function, that would be a good start!

- Gemma3 multimodal (4B, 12B, 27B).
- Qwen3 (0.6B, 4B, 8B, 14B, 32B).
- Mixtral (8x7B, 8x22B).
| Model Family | Sizes | HF $\to$ Orbax (scan) | HF $\to$ Orbax (unscan) | Orbax (scan) $\to$ HF | Orbax (unscan) $\to$ HF |
Collaborator comment:

Thanks for adding this table!

def _initialize_index(self):
"""Fetches and parses the Hugging Face model index file to build a shard map."""
files = list_repo_files(self.model_id, token=self.token)
if self.is_local:
Collaborator comment:

Can we add this option in the README.md?

Author reply:

I have added --hf_model_path to the README.

load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_final, hook_fn)
load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape, hook_fn)
# Stacked mapping
elif isinstance(hf_source_keys_or_key[0], list):
Collaborator comment:

I was really confused when onboarding Mixtral models. Do you think we could add a decision tree or something similar at the top?

Author reply:

I have added more comments for these conditions.
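
For readers of this thread, a hedged sketch of the dispatch those comments describe (hypothetical helper name; the cases mirror the hf-key variants enumerated earlier):

def classify_hf_source(hf_source_keys_or_key):
  # Single hf key: unscanned 1-to-1 (or 1-to-many via a splitting hook).
  if not isinstance(hf_source_keys_or_key, list):
    return "single"
  # Nested list: scanned experts (per-expert lists of per-layer keys).
  if isinstance(hf_source_keys_or_key[0], list):
    return "nested_list"
  # Flat list: scanned layers, or unscanned per-expert keys.
  return "flat_list"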

@RissyRan (Collaborator) left a comment

Thanks! If you could wrap your added code into helper functions for those 8 cases, that would be really appreciated :) Also, let's create a bug to track the rest of the refactor.


