From 171a3c63750784a22de37367cee64ab73f3f4c60 Mon Sep 17 00:00:00 2001 From: Shuning Jin Date: Tue, 6 Jan 2026 22:42:35 +0000 Subject: [PATCH] Checkpoint utility: add gpt-oss to_maxtext & refactor code --- .../tpu/deepseek/v2-16b/test_deepseek.sh | 2 +- .../tpu/deepseek/v3-671b/2_test_deepseek.sh | 2 +- .../scratch_code/generate_hf_golden_logits.py | 9 +- src/MaxText/utils/ckpt_conversion/README.md | 24 +- .../utils/ckpt_conversion/to_huggingface.py | 71 +-- .../utils/ckpt_conversion/to_maxtext.py | 403 ++++++++++++------ .../ckpt_conversion/utils/param_mapping.py | 399 +++++++++-------- .../utils/ckpt_conversion/utils/utils.py | 265 +++++++----- tests/forward_pass_logit_checker.py | 19 +- 9 files changed, 694 insertions(+), 500 deletions(-) diff --git a/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh b/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh index 61687c4ca4..0c2a1817ce 100644 --- a/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh +++ b/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh @@ -10,7 +10,7 @@ # Example Usage: export HF_TOKEN=; export BASE_OUTPUT_PATH=; bash test_deepseek.sh # The golden logit can be generated by: -# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --not-trust-remote-code +# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --trust-remote-code=False set -ex diff --git a/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh b/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh index 9cba02720a..4c3fd4bff7 100644 --- a/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh +++ b/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh @@ -8,7 +8,7 @@ # 2. Run logit check, pre-training, fine-tuning, and decoding. 
# The golden logit can be generated by: -# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V3 --output-path=golden_data_deepseek3-671b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --not-trust-remote-code --hf-load-dtype=bfloat16 +# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V3 --output-path=golden_data_deepseek3-671b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16 set -ex diff --git a/src/MaxText/scratch_code/generate_hf_golden_logits.py b/src/MaxText/scratch_code/generate_hf_golden_logits.py index 73b19a5e0f..e119ab1d34 100644 --- a/src/MaxText/scratch_code/generate_hf_golden_logits.py +++ b/src/MaxText/scratch_code/generate_hf_golden_logits.py @@ -47,6 +47,7 @@ import numpy as np from google.cloud import storage from PIL import Image +from MaxText.inference_utils import str2bool # Load the tokenizer and model from Hugging Face @@ -184,11 +185,11 @@ def main(raw_args=None) -> None: default="float32", help="model_class.from_pretrained: dtype", ) - # variable `args.trust_remote_code` is True by default, False only if with flag `--not-trust-remote-code` parser.add_argument( - "--not-trust-remote-code", - dest="trust_remote_code", - action="store_false", + "--trust-remote-code", + type=str2bool, + required=False, + default=True, help="model_class.from_pretrained: trust_remote_code", ) parser.add_argument( diff --git a/src/MaxText/utils/ckpt_conversion/README.md b/src/MaxText/utils/ckpt_conversion/README.md index edb7c15256..fc073c2ec6 100644 --- a/src/MaxText/utils/ckpt_conversion/README.md +++ b/src/MaxText/utils/ckpt_conversion/README.md @@ -6,10 +6,17 @@ This guide provides instructions for using the scripts that convert model checkp The following models are supported: -- Gemma2 (2B, 9B, 27B). -- Gemma3 multimodal (4B, 12B, 27B). -- Qwen3 (0.6B, 4B, 8B, 14B, 32B). -- Mixtral (8x7B, 8x22B). +| Model Family | Sizes | HF $\to$ Orbax (scan) | HF $\to$ Orbax (unscan) | Orbax (scan) $\to$ HF | Orbax (unscan) $\to$ HF | +| :--- | :--- | :---: | :---: | :---: | :---: | +| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ | +| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ | +| **Llama3.1** | 8B, 70B, 450B | √ | √ | √ | √ | +| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ | +| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | +| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | +| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | +| **DeepSeek3** | 671B | - | - | √ | - | + ## Prerequisites - Hugging Face requires Pytorch. @@ -42,8 +49,9 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml * `use_multimodal`: Indicates if multimodality is used, important for Gemma3. * `hf_access_token`: Your Hugging Face token. * `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`. + * `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. + * `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/2f77e7b5fcc4b580bc2d109525c362f3d9056ec9/src/MaxText/utils/ckpt_conversion/utils/utils.py#L54-L82) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. 
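+
+  For example, to convert a GPT-OSS checkpoint that has already been dequantized into a local directory, a command along the following lines can be used (the output bucket and local path are placeholders, the `model_name` must match your MaxText config, and flag values should be adjusted to your setup):
+
+  ```bash
+  python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+      model_name=gpt-oss-20b \
+      hf_access_token=$HF_TOKEN \
+      base_output_directory=gs://my-bucket/gpt-oss-20b-orbax \
+      scan_layers=true \
+      use_multimodal=false \
+      --lazy_load_tensors=true \
+      --hf_model_path=/path/to/local/dequantized/gpt-oss-20b
+  ```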
-\*\**It only converts the official version of Hugging Face model. You can refer the supported official version in HF_IDS in `src/MaxText/utils/ckpt_conversion/utils/utils.py`* ## MaxText to Hugging Face @@ -62,6 +70,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base scan_layers=false \ use_multimodal=false \ hf_access_token= \ + weight_dtype=bfloat16 ``` **Key arguments:** @@ -72,6 +81,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base * `hf_access_token`: Your Hugging Face token. * `use_multimodal`: Indicates if multimodality is used, important for Gemma3. * `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS), Hugging Face Hub or local. If not set, the default output directory is `Maxtext/tmp`. + * `weight_dtype`: dtype for MaxText weights. It affects the resulting HF weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion. ## Verifying conversion correctness @@ -87,11 +97,11 @@ python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \ model_name= \ scan_layers=false \ max_prefill_predict_length=4 \ - max_target_length=8 \ + max_target_length=8 \ use_multimodal=false \ --run_hf_model=True \ --hf_model_path= \ - --max_kl_div=0.015 \ + --max_kl_div=0.015 ``` **Key arguments:** diff --git a/src/MaxText/utils/ckpt_conversion/to_huggingface.py b/src/MaxText/utils/ckpt_conversion/to_huggingface.py index a5f09f9637..9c7d16c31b 100644 --- a/src/MaxText/utils/ckpt_conversion/to_huggingface.py +++ b/src/MaxText/utils/ckpt_conversion/to_huggingface.py @@ -72,8 +72,12 @@ ) from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS -from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS - +from MaxText.utils.ckpt_conversion.utils.utils import ( + validate_and_filter_param_map_keys, + process_maxtext_param, + save_model_files, + HF_IDS, +) os.environ["JAX_PLATFORMS"] = "cpu" os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16" @@ -107,59 +111,6 @@ def _get_model_mappings( } -def _check_param_map_keys(param_map_keys, maxtext_state_keys): - """Validates map coverage, handles N-to-1 mappings, and filters unused keys. - - Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by - the flattened parameter map. Keys in the map that are not present in the - checkpoint (common for multi-variant maps like gemma3, qwen3, deepseek) are skipped. - - Tuple keys represent N-to-1 mappings (multiple MaxText keys combining into one - target key) and are only returned if all constituent keys exist in the checkpoint. - - Args: - param_map_keys: Keys from the parameter mapping (strings or N-to-1 tuples). - maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint. - - Returns: - A list of 'filtered' mapping keys (strings or tuples) that are fully present - and valid based on `maxtext_state_keys`. - - Raises: - ValueError: If `maxtext_state_keys` is NOT a subset of the flattened - `param_map_keys`. 
- """ - flattened_map_keys = set() - for key in param_map_keys: - if isinstance(key, tuple): - flattened_map_keys.update(key) - else: - flattened_map_keys.add(key) - - # every maxtext state key must be covered by param map - missing_keys = maxtext_state_keys - flattened_map_keys - if missing_keys: - raise ValueError( - "maxtext_state_dict must be a subset of flattened param_map" - + f"\nparam map\n{param_map_keys}" - + f"\nmaxtext:\n{maxtext_state_keys}" - ) - - # param map may have extra keys - extra_keys = flattened_map_keys - maxtext_state_keys - if extra_keys: - max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}") - - # skip extra keys in param map - filtered_map_keys = [] - for key in param_map_keys: - if (isinstance(key, str) and key in maxtext_state_keys) or ( - isinstance(key, tuple) and all(k in maxtext_state_keys for k in key) - ): - filtered_map_keys.append(key) - return filtered_map_keys - - def main(argv: Sequence[str]) -> None: """Main function to convert a MaxText checkpoint to HuggingFace format. @@ -180,6 +131,7 @@ def main(argv: Sequence[str]) -> None: config.load_full_state_path == "" ), "This script expects parameters, not a full state. Use generate_param_only_checkpoint first if needed." max_utils.print_system_information() + overall_start = time.time() # Load Maxtext checkpoint max_logging.log("\nLoading Orbax checkpoint...") @@ -189,7 +141,7 @@ def main(argv: Sequence[str]) -> None: rng, rng_load_params = jax.random.split(rng) # load params from maxengine loaded_params_from_engine = engine.load_params(rng_load_params) - max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min") + max_logging.log(f"Elapse for checkpoint load: {(time.time() - start) / 60:.2f} min") if not config.base_output_directory: output_directory = f"tmp/{config.run_name}" @@ -239,7 +191,7 @@ def main(argv: Sequence[str]) -> None: # The param_map may contain tuples as keys, which represent N-to-1 mappings from maxtext to huggingface # Check maxtext_state_dict is a subset of flattened param_map # Skip extra keys from param_map - filtered_map_keys = _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys()) + filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys()) # Iterate through the parameter map to transform and collect weights. # This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings @@ -260,7 +212,7 @@ def main(argv: Sequence[str]) -> None: processed_params_list.extend(processed_params) transformed_hf_weights = dict(processed_params_list) - max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min") + max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min") # 5. Save in HuggingFace Format if not transformed_hf_weights: @@ -277,7 +229,8 @@ def main(argv: Sequence[str]) -> None: output_dir=output_directory, ) max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}") - max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min") + max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min") + max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min") if __name__ == "__main__": diff --git a/src/MaxText/utils/ckpt_conversion/to_maxtext.py b/src/MaxText/utils/ckpt_conversion/to_maxtext.py index 2caa4144a7..6a4e0323c9 100644 --- a/src/MaxText/utils/ckpt_conversion/to_maxtext.py +++ b/src/MaxText/utils/ckpt_conversion/to_maxtext.py @@ -25,9 +25,11 @@ Defaults to "./mt_output/". 
scan_layers: (bool) Whether the MaxText model was trained with scanned layers. This must match the training configuration of the checkpoint. - lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM + --lazy_load_tensors: (bool) If True, uses an on-demand loading strategy to minimize RAM usage during conversion. Recommended if, 2 * model_size (GB) >= system RAM Defaults to False. + --hf_model_path: (Optional) Specify a local HF path, rather than the default repo `HF_IDS[model_name]`. + Useful for locally dequantized HF model like GPT-OSS or DeepSeek. Environment Variables: HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to @@ -56,6 +58,7 @@ import argparse import os +import time import sys import json import threading @@ -70,6 +73,7 @@ from tqdm import tqdm from huggingface_hub import hf_hub_download, list_repo_files from safetensors import safe_open +import absl from orbax.checkpoint import type_handlers from MaxText import checkpointing @@ -82,10 +86,12 @@ from MaxText.layers import models, quantizations from MaxText.checkpointing import save_checkpoint from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model +from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys jax.config.update("jax_platform_name", "cpu") +absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log + class MemoryMonitorTqdm(tqdm): """Custom tqdm class that displays memory usage in the progress bar.""" @@ -144,6 +150,8 @@ class LazyHFLoader: def __init__(self, model_id, token): self.model_id = model_id self.token = token + # Whether loads from local directory + self.is_local = os.path.isdir(self.model_id) self.shard_map = {} self.current_shard_name = None self.current_shard_content = {} @@ -164,7 +172,10 @@ def __setstate__(self, state): def _initialize_index(self): """Fetches and parses the Hugging Face model index file to build a shard map.""" - files = list_repo_files(self.model_id, token=self.token) + if self.is_local: + files = os.listdir(self.model_id) + else: + files = list_repo_files(self.model_id, token=self.token) # Prefer safetensors if "model.safetensors.index.json" in files: @@ -178,7 +189,10 @@ def _initialize_index(self): # Download and parse the index max_logging.log(f"Loading index file: {index_file}") - index_path = hf_hub_download(repo_id=self.model_id, filename=index_file, token=self.token) + if self.is_local: + index_path = os.path.join(self.model_id, index_file) + else: + index_path = hf_hub_download(repo_id=self.model_id, filename=index_file, token=self.token) with open(index_path, "r", encoding="utf-8") as f: index_data = json.load(f) self.shard_map = index_data["weight_map"] @@ -203,9 +217,12 @@ def get_tensor(self, key: str) -> np.ndarray: # You might need advanced fuzzy matching here if you encounter errors. raise ValueError(f"Key {key} not found in HF checkpoint index.") - # STEP 1: Download outside the lock. - # multiple threads can download different shards at the same time. - local_path = hf_hub_download(repo_id=self.model_id, filename=shard_name, token=self.token) + if self.is_local: + local_path = os.path.join(self.model_id, shard_name) + else: + # STEP 1: Download outside the lock. + # multiple threads can download different shards at the same time. 
+ local_path = hf_hub_download(repo_id=self.model_id, filename=shard_name, token=self.token) # STEP 2: Lock ONLY the reading into RAM. # This prevents multiple threads from simultaneously allocating large chunks of RAM. @@ -258,6 +275,9 @@ def __array__(self, dtype=None): # Re-raise the original exception so it doesn't get masked by "object __array__..." raise + if not isinstance(self.shape, list) and arr.shape != self.shape: + raise ValueError(f"Shape mismatch for tensor '{self.name}'. Expected {self.shape}, but got {arr.shape}.") + # Ensure it's a standard numpy array (converts JAX arrays if necessary) if not isinstance(arr, np.ndarray): arr = np.array(arr) @@ -293,6 +313,51 @@ async def serialize(self, value, *args, **kwargs): type_handlers.register_type_handler(LazyTensor, LazyTensorHandler(), override=True) +def get_maxtext_model_info(config): + """Initializes the abstract MaxText model and returns parameter mapping information. + + Args: + config: The MaxText configuration object. + + Returns: + maxtext_abstract_dict: A dictionary mapping MaxText parameter keys to a tuple + (index, target_shape), where 'index' is the position of the parameter in the + flattened parameter list. + abstract_params_treedef: The tree structure definition of the abstract model parameters. + """ + # Setup JAX distributed system and mesh + devices_array = maxtext_utils.create_device_mesh(config) + mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) + + max_logging.log("Initializing MaxText abstract model...") + quant = quantizations.configure_quantization(config) + maxtext_model_flax = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + + # Get abstract model structure (name, shape) without materializing the weights to save memory + abstract_params_tree = maxtext_utils.get_abstract_param(maxtext_model_flax, config)["params"] + + abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_params_tree) + # Standardize abstract tree for later unflattening + abstract_params_tree = jax.tree.map( + lambda _: 0, + abstract_params_tree, + is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), + ) + abstract_params_treedef = jax.tree_util.tree_structure(abstract_params_tree) + + max_logging.log("MaxText abstract model and state initialized.") + + # preprocess state + maxtext_abstract_dict = {} + for mt_target_idx, (path_tuple, abstract_leaf_value) in enumerate(abstract_params_flat): + key_parts = [k.key for k in path_tuple if hasattr(k, "key")] + mt_param_key = "params-" + "-".join(key_parts) + mt_target_shape = abstract_leaf_value.shape + maxtext_abstract_dict[mt_param_key] = (mt_target_idx, mt_target_shape) + + return maxtext_abstract_dict, abstract_params_treedef + + def _build_multi_axis_stacked_tensor( hf_source_keys: List[List[str]], tensor_getter_fn: Callable[[str], np.ndarray], @@ -356,11 +421,13 @@ def _build_single_axis_stacked_tensor( The final, assembled NumPy array for the MaxText parameter. """ tensors_to_stack = [] - # Heuristic to determine if we are stacking layers or experts. - # If the number of items to stack equals the number of layers, it's a standard - # scanned layer, and we use the configured param_scan_axis. Otherwise, it's - # an unscanned MoE layer, and we stack along the expert axis (0). - axis_to_stack = config.param_scan_axis if len(hf_source_keys) == config.base_num_decoder_layers else 0 + + if config.scan_layers: + # If it's a standard scanned layer, we use the configured param_scan_axis. 
+ axis_to_stack = config.param_scan_axis + else: + # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0). + axis_to_stack = 0 # The hook function needs the shape of an individual slice, not the full stacked tensor. # We calculate it by removing the stacking dimension from the final target shape. @@ -377,7 +444,142 @@ def _build_single_axis_stacked_tensor( return np.stack(tensors_to_stack, axis=axis_to_stack) +def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config): + """Determine the loading function for HF keys. + HF keys can take four forms: + Case 1: Unscanned (single string) + Case 2: Scanned (list of strings) + Case 3: Unscanned with expert stacking (list of strings) + Case 4: Scanned with expert stacking (nested list of strings) + """ + load_fn = None + if not isinstance(hf_source_keys_or_key, list): + # Case 1: Single hf key (str) + def _loader(getter, key, shape, hook): + return apply_hook_fns(getter(key), shape, hook) + + load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_or_shapes, hook_fn) + # Stacked mapping + elif not isinstance(hf_source_keys_or_key[0], list): + # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) + load_fn = partial( + _build_single_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) + else: + # isinstance(hf_source_keys_or_key[0], list) + # Case 4: Multi-Axis Stacked hf keys (nested list) + load_fn = partial( + _build_multi_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) + return load_fn + + +def _get_maxtext_indices_and_shapes(mt_param_key_or_keys, maxtext_abstract_dict): + """Resolves MaxText key(s) to target indices and shapes. + + The index is the parameter's order in `maxtext_abstract_dict.keys()`. + This function handles two forms of MaxText keys: + - `atomic_mt_key`: A single string representing one MaxText parameter that map to HF parameter(s). + - `composite_mt_key`: A tuple of strings for multiple MaxText parameters that map to HF parameter(s). + """ + is_composite_mt_key = isinstance(mt_param_key_or_keys, tuple) + # atomic_mt_key + if not is_composite_mt_key: + mt_target_idx, mt_target_shape = maxtext_abstract_dict[mt_param_key_or_keys] + return mt_target_idx, mt_target_shape + # composite_mt_key + mt_target_indices, mt_target_shapes = [], [] + for mt_param_key in mt_param_key_or_keys: + mt_target_idx, mt_target_shape = maxtext_abstract_dict[mt_param_key] + mt_target_indices.append(mt_target_idx) + mt_target_shapes.append(mt_target_shape) + return mt_target_indices, mt_target_shapes + + +def _get_maxtext_weight( + load_fn, + mt_target_idx_or_indices, + mt_target_shape_or_shapes, + mt_param_key_or_keys, + final_mt_weights, + config, + use_lazy_load, +): + """Loads Hugging Face parameters and converts them to MaxText parameters. + + This function handles loading based on tensor mode (eager or lazy) and + processes MaxText keys, which can be `atomic_mt_key` or `composite_mt_key`. + """ + is_composite_mt_key = isinstance(mt_param_key_or_keys, tuple) + if not use_lazy_load: + # Case 1: Eager mode + # In eager mode, we execute the function immediately to get the + # NumPy array and append it to our list of weights. 
+ final_mt_tensor_numpy = load_fn() + if not is_composite_mt_key: + # Case 1.1: Eager mode, `atomic_mt_key` + final_mt_weights[mt_target_idx_or_indices] = final_mt_tensor_numpy + if final_mt_tensor_numpy.shape != mt_target_shape_or_shapes: + raise ValueError( + f"Shape mismatch for {mt_param_key_or_keys}: Expected {mt_target_shape_or_shapes}, " + f"got {final_mt_tensor_numpy.shape}" + ) + else: + # Case 1.2: Eager mode, `composite_mt_key` + # The hook returns a tensor that can be split in last dim. + # In eager mode, we can just split the materialized tensor. + for i, mt_target_idx in enumerate(mt_target_idx_or_indices): + final_mt_weights[mt_target_idx] = final_mt_tensor_numpy[..., i] + if final_mt_weights[mt_target_idx].shape != mt_target_shape_or_shapes[i]: + raise ValueError( + f"Shape mismatch for {mt_param_key_or_keys[i]}: Expect {mt_target_shape_or_shapes[i]}, " + f"got {final_mt_weights[mt_target_idx].shape}" + ) + else: + # Case 2: Lazy mode + # In lazy mode, we don't execute the loading/transformation function + # immediately. Instead, we wrap it in a `LazyTensor` object. This + # object acts as a placeholder that holds all the information needed + # to load the tensor later (the `load_fn`, shape, dtype). + # The actual data will only be loaded when Orbax calls `__array__` + # on this object during the saving process. + final_mt_tensor_numpy = LazyTensor(load_fn, mt_target_shape_or_shapes, config.weight_dtype, name=mt_param_key_or_keys) + if not is_composite_mt_key: + # Case 2.1: Lazy mode, `atomic_mt_key` + final_mt_weights[mt_target_idx_or_indices] = final_mt_tensor_numpy + else: + # Case 2.2: Lazy mode, `composite_mt_key` + # For a composite key, the hook returns a tensor that can be split in last dim. + # For lazy loading, we can't split the tensor until it's loaded. + # We create multiple LazyTensors, each responsible for loading the + # full source tensor but then slicing its piece. Parent HF tensor is loaded repeatedly. + for i, mt_target_idx in enumerate(mt_target_idx_or_indices): + + def _slicing_loader(base_loader, slice_idx): + return np.array(base_loader)[..., slice_idx] + + # Each LazyTensor gets a new load_fn that wraps the original and applies the slice. + slicing_load_fn = partial(_slicing_loader, final_mt_tensor_numpy, i) + final_mt_weights[mt_target_idx] = LazyTensor( + slicing_load_fn, + mt_target_shape_or_shapes[i], + config.weight_dtype, + name=mt_param_key_or_keys[i], + ) + + def main(args: Sequence[str], test_args: Sequence[str]) -> None: + overall_start = time.time() # Check if the user is using an Instruct version. If so, use the base model architecture for i, arg in enumerate(args): if arg.startswith("model_name="): @@ -389,22 +591,24 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None: args[i] = f"model_name={model_name_arg}" break - config = pyconfig.initialize(args) # check the supported model ids if model_name_original not in HF_IDS: raise ValueError(f"Unsupported model name: {model_name_original}. 
Supported models are: {list(HF_IDS.keys())}") - model_id = HF_IDS[model_name_original] + if not test_args.hf_model_path: + model_id = HF_IDS[model_name_original] + else: + model_id = test_args.hf_model_path + + # Initialize maxtext config + config = pyconfig.initialize(args) max_utils.print_system_information() + if not config.base_output_directory: output_directory = f"tmp/{config.run_name}" else: output_directory = config.base_output_directory - # Setup JAX distributed system and mesh - devices_array = maxtext_utils.create_device_mesh(config) - mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - hf_token = config.hf_access_token use_lazy_load = test_args.lazy_load_tensors @@ -415,11 +619,13 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None: hf_state_dict_numpy = None hf_loader = None + # Define the appropriate tensor getter based on mode if use_lazy_load: max_logging.log(f"Lazy loading ENABLED. Initializing LazyHFLoader for: {model_id}...") hf_loader = LazyHFLoader(model_id, hf_token) hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token) print_ram_usage("After LazyLoader init") + tensor_getter = hf_loader.get_tensor else: max_logging.log(f"Lazy loading DISABLED. Loading full HuggingFace model: {model_id}...") hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token) @@ -432,39 +638,17 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None: max_logging.log("HuggingFace model loaded and converted to NumPy.") print_ram_usage("After full HF model load") - checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - output_directory, - enable_checkpointing=True, - use_async=False, # Synchronous saving for simplicity in conversion script - save_interval_steps=1, # Save at step 0 - use_ocdbt=config.checkpoint_storage_use_ocdbt, - use_zarr3=config.checkpoint_storage_use_zarr3, - ) - - max_logging.log("Initializing MaxText abstract model...") - quant = quantizations.configure_quantization(config) - maxtext_model_flax = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - - # Get abstract model structure (name, shape) without materializing the weights to save memory - abstract_params_tree = maxtext_utils.get_abstract_param(maxtext_model_flax, config)["params"] - - abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_params_tree) - # Standardize abstract tree for later unflattening - abstract_params_tree = jax.tree.map( - lambda _: 0, - abstract_params_tree, - is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), - ) - abstract_params_treedef = jax.tree_util.tree_structure(abstract_params_tree) - del abstract_params_tree + def _eager_getter(key): + if key not in hf_state_dict_numpy: + raise ValueError(f"HuggingFace key {key} not found in state_dict.") + return hf_state_dict_numpy[key] - max_logging.log("MaxText abstract model and state initialized.") + tensor_getter = _eager_getter # Get parameter mappings and hooks # example of param mapping (gemma2, maxtext:huggingface): # "params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_global-scale": # f"model.layers.{global_layer_idx}.input_layernorm.weight", - model_key = config.model_name param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers) @@ -473,87 +657,61 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None: hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False) max_logging.log("Parameter mappings 
and hooks obtained.") - max_logging.log("Starting weight transformation...") - final_mt_weights = [] + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( + output_directory, + enable_checkpointing=True, + use_async=False, # Synchronous saving for simplicity in conversion script + save_interval_steps=1, # Save at step 0 + use_ocdbt=config.checkpoint_storage_use_ocdbt, + use_zarr3=config.checkpoint_storage_use_zarr3, + ) - # Define the appropriate tensor getter based on mode - if use_lazy_load: - tensor_getter = hf_loader.get_tensor - else: + maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config) - def _eager_getter(key): - if key not in hf_state_dict_numpy: - raise ValueError(f"HuggingFace key {key} not found in state_dict.") - return hf_state_dict_numpy[key] + # Weight transformation + max_logging.log("Starting weight transformation...") + start = time.time() + final_mt_weights = [None] * len(maxtext_abstract_dict) - tensor_getter = _eager_getter + # Preprocess key + filtered_map_keys = validate_and_filter_param_map_keys(param_map_mt_to_hf.keys(), maxtext_abstract_dict.keys()) - for path_tuple, abstract_leaf_value in MemoryMonitorTqdm( - abstract_params_flat, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True + for mt_param_key_or_keys in MemoryMonitorTqdm( + filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True ): - key_parts = [k.key for k in path_tuple if hasattr(k, "key")] - mt_param_key = "params-" + "-".join(key_parts) - mt_target_shape_final = abstract_leaf_value.shape - - hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key) + if not use_lazy_load and config.scan_layers: + max_logging.log(f"maxtext param: {mt_param_key_or_keys}") + hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys) if hf_source_keys_or_key is None: - raise ValueError(f"MaxText parameter {mt_param_key} not found in mapping.") - - hook_fn = hook_fn_map_mt.get(mt_param_key) - - # Determine the loading function for this specific parameter - load_fn = None - if not isinstance(hf_source_keys_or_key, list): - # Case 1: Simple 1-to-1 mapping - def _loader(getter, key, shape, hook): - return apply_hook_fns(getter(key), shape, hook) - - load_fn = partial(_loader, tensor_getter, hf_source_keys_or_key, mt_target_shape_final, hook_fn) - else: - # Stacked mapping - if isinstance(hf_source_keys_or_key[0], list): - # Case 2: Multi-Axis Stacked - load_fn = partial( - _build_multi_axis_stacked_tensor, - hf_source_keys_or_key, - tensor_getter, - hook_fn, - mt_target_shape_final, - config, - ) - else: - # Case 3: Single-Axis Stacked - load_fn = partial( - _build_single_axis_stacked_tensor, - hf_source_keys_or_key, - tensor_getter, - hook_fn, - mt_target_shape_final, - config, - ) - - # Execute based on mode - if use_lazy_load: - # In lazy mode, we don't execute the loading/transformation function - # immediately. Instead, we wrap it in a `LazyTensor` object. This - # object acts as a placeholder that holds all the information needed - # to load the tensor later (the `load_fn`, shape, dtype). - # The actual data will only be loaded when Orbax calls `__array__` - # on this object during the saving process. - final_mt_weights.append(LazyTensor(load_fn, mt_target_shape_final, abstract_leaf_value.dtype, name=mt_param_key)) - else: - # In eager mode, we execute the function immediately to get the - # NumPy array and append it to our list of weights. 
- final_mt_tensor_numpy = load_fn() - if final_mt_tensor_numpy.shape != mt_target_shape_final: - raise ValueError( - f"Shape mismatch for {mt_param_key}: Expected {mt_target_shape_final}, got {final_mt_tensor_numpy.shape}" - ) - final_mt_weights.append(final_mt_tensor_numpy) - - del abstract_params_flat, hf_state_dict_numpy + raise ValueError(f"MaxText parameter {mt_param_key_or_keys} not found in mapping.") + hook_fn = hook_fn_map_mt.get(mt_param_key_or_keys) + + # Step 1: Resolves MaxText key(s) to target indices and shapes + # based on MaxText key form (`atomic_mt_key` or `composite_mt_key`) + mt_target_idx_or_indices, mt_target_shape_or_shapes = _get_maxtext_indices_and_shapes( + mt_param_key_or_keys, maxtext_abstract_dict + ) + + # Step 2: Determine the loading function for hf key + # based on hf_key form (unscanned, scanned, unscanned with expert stacking, or scanned with expert stacking) + load_fn = _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config) + + # Step 3: Load hf keys and convert to maxtext keys + # based on tensor load mode (lazy, eager) and MaxText key form (`atomic_mt_key` or `composite_mt_key`) + _get_maxtext_weight( + load_fn, + mt_target_idx_or_indices, + mt_target_shape_or_shapes, + mt_param_key_or_keys, + final_mt_weights, + config, + use_lazy_load, + ) + + del hf_state_dict_numpy max_logging.log("Weight transformation preparation complete.") + max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min") print_ram_usage("Before creating full JAX tree") # Create final MaxText parameters tree @@ -566,6 +724,7 @@ def _loader(getter, key, shape, hook): del final_params_for_state print_ram_usage("Before saving") + start = time.time() if checkpoint_manager is not None: if use_lazy_load: max_logging.log("Starting checkpoint save (loading weights just-in-time)...") @@ -582,6 +741,8 @@ def _loader(getter, key, shape, hook): print_ram_usage("Program Ends") max_logging.log(f"Conversion complete. Checkpoint saved to {output_directory}") + max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min") + max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min") if __name__ == "__main__": @@ -596,9 +757,13 @@ def _loader(getter, key, shape, hook): default=False, help="Whether to use lazy loading of HF tensors.", ) + # if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name] + parser.add_argument( + "--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo" + ) local_args, _ = parser.parse_known_args() model_args = sys.argv - to_remove_args = ["--lazy_load_tensors"] + to_remove_args = ["--lazy_load_tensors", "--hf_model_path"] for a in to_remove_args: model_args = [s for s in model_args if not s.startswith(a)] main(model_args, local_args) diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py index 24606ced46..a38e77c8f4 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py +++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py @@ -18,10 +18,21 @@ MaxText and Hugging Face formats for various architectures (e.g., Gemma, Qwen). It provides two key types of mappings for each model: -1. **Parameter Name Mappings (`PARAM_MAPPING`)**: Dictionaries that map the string - name of a parameter in a MaxText checkpoint to its corresponding name in a - Hugging Face checkpoint. 
These mappings are generated by functions like - `GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING`. +1. **Parameter Name Mappings (`PARAM_MAPPING`)**: Dictionaries that map a MaxText + parameter key to its corresponding Hugging Face parameter(s). These mappings are + generated by functions like `GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING`. + + **Key: MaxText parameters, with following forms:** + - `atomic_mt_key`: A single string representing one MaxText parameter. + - `composite_mt_key`: A tuple of strings representing multiple MaxText parameters. (e.g., GPT-OSS) + + **Value: corresponding Hugging Face parameters, with following forms:** + - `unscanned`: A single string. + - `scanned`: A list of strings, to be stacked along the layer axis. + - `unscanned with expert stacking`: A list of strings, to be stacked along the expert axis. + - `scanned with expert stacking`: A nested list of strings, to be stacked along both layer and expert axes. + Note: Expert stacking only applies a subset of MoE models (e.g., Qwen MoE, DeepSeek, Mixtral), + but not others (e.g., GPT-OSS). 2. **Hook Functions (`HOOK_FNS`)**: Dictionaries that map a MaxText parameter name to a specific transformation function (a "hook"). These hooks handle @@ -57,9 +68,9 @@ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False generates mappings for individual, unscanned layers. Defaults to False. Returns: - dict: A mapping where keys are MaxText parameter names and values are the - corresponding Hugging Face parameter names. For scanned text layers, the - value is a list of Hugging Face names. + dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names). Values + are either a single Hugging Face parameter name (unscanned form) or a list of + Hugging Face parameter names (scanned form) for stacked text layers. """ tcfg = config["text_config"] vcfg = config["vision_config"] @@ -309,10 +320,9 @@ def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False single tensor. Defaults to False. Returns: - dict: A mapping where keys are MaxText parameter paths and values are - either single strings (HF parameter path) for unscanned parameters or - lists of strings (HF parameter paths) for stacked layers when - `scan_layers=True`. + dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter name). + Values are either a single string (unscanned form) or a list of strings + (scanned form) for stacked layers when `scan_layers=True`. 
Notes: - MaxText uses a paired layer approach where two HF decoder layers are @@ -467,62 +477,41 @@ def pad_hf_embedding_layer(input_tensor, target_shape): # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype normalizer = np.dtype("float32").type(config["hidden_size"] ** 0.5) - def to_hf(): + if saving_to_hf: target_tensor = input_tensor[: target_shape[0], : target_shape[1]] target_tensor = target_tensor / normalizer target_tensor = target_tensor.astype(input_tensor.dtype) return target_tensor - - def from_hf(): + else: target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype) target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor target_tensor = target_tensor * normalizer target_tensor = target_tensor.astype(input_tensor.dtype) return target_tensor - if saving_to_hf: - return to_hf() - else: - return from_hf() - def reshape_kernel(input_tensor, target_shape): - def to_hf(): + if saving_to_hf: flipped_target_shape = np.flip(np.array(target_shape)) return input_tensor.reshape(flipped_target_shape).T - - def from_hf(): - return input_tensor.T.reshape(target_shape) - - if saving_to_hf: - return to_hf() else: - return from_hf() + return input_tensor.T.reshape(target_shape) def scale_rmsnorm_layer(input_tensor, target_shape): - def to_hf(): - return (input_tensor - 1.0).reshape(target_shape) - - def from_hf(): - return (input_tensor + 1.0).reshape(target_shape) - if saving_to_hf: - return to_hf() + return (input_tensor - 1.0).reshape(target_shape) else: - return from_hf() + return (input_tensor + 1.0).reshape(target_shape) def scale_query_layer(input_tensor, target_shape): - def to_hf(): + if saving_to_hf: depth_scale = np.dtype("float32").type(np.sqrt(config["head_dim"])) return (input_tensor * depth_scale).astype(input_tensor.dtype) - - def from_hf(): + else: depth_scale = np.dtype("float32").type(1 / np.sqrt(config["head_dim"])) return (input_tensor * depth_scale).astype(input_tensor.dtype) - if saving_to_hf: - return to_hf() - else: - return from_hf() + # hook order does not affect result + query_hook_chain = [reshape_kernel, scale_query_layer] mapping = { "params-token_embedder-embedding": pad_hf_embedding_layer, @@ -531,14 +520,8 @@ def from_hf(): if scan_layers: mapping = { **mapping, - "params-decoder-layers-self_attention_global-query-kernel": [ - reshape_kernel, - scale_query_layer, - ], - "params-decoder-layers-self_attention_local-query-kernel": [ - reshape_kernel, - scale_query_layer, - ], + "params-decoder-layers-self_attention_global-query-kernel": query_hook_chain, + "params-decoder-layers-self_attention_local-query-kernel": query_hook_chain, "params-decoder-layers-self_attention_global-key-kernel": reshape_kernel, "params-decoder-layers-self_attention_local-key-kernel": reshape_kernel, "params-decoder-layers-self_attention_global-value-kernel": reshape_kernel, @@ -564,14 +547,8 @@ def from_hf(): for maxtext_layer_idx in range(nlayers // 2): mapping = { **mapping, - f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-query-kernel": [ - reshape_kernel, - scale_query_layer, - ], - f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-query-kernel": [ - reshape_kernel, - scale_query_layer, - ], + f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-query-kernel": query_hook_chain, + f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-query-kernel": query_hook_chain, f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel, 
f"params-decoder-layers_{maxtext_layer_idx}-self_attention_local-key-kernel": reshape_kernel, f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-value-kernel": reshape_kernel, @@ -610,9 +587,10 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False) layers. Defaults to False. Returns: - dict: A mapping from MaxText parameter names to Hugging Face parameter - names. For scanned or MoE layers, the value may be a list or a nested - list of names. + dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names). + Values are Hugging Face parameter names in one of four forms: unscanned (string), + scanned (list of strings), unscanned with expert stacking (list of strings), + or scanned with expert stacking (nested list of strings). """ n_layers = config["num_hidden_layers"] num_experts = config.get("num_experts", 0) @@ -815,7 +793,14 @@ def reshape_kernel(input_tensor, target_shape): def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): - """Returns mapping from MaxText to HuggingFace Deepseek weight paths using f-strings.""" + """Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths. + + Returns: + dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names). + Values are Hugging Face parameter names in one of four forms: unscanned (string), + scanned (list of strings), unscanned with expert stacking (list of strings), + or scanned with expert stacking (nested list of strings). + """ # TODO(shuningjin): add unscan support, b/457820735 if not scan_layers: raise NotImplementedError("This conversion only supports scanned MaxText models.") @@ -886,7 +871,7 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fal def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): - """Creates parameter transformation functions for Deepseek using f-strings.""" + """Creates parameter transformation functions for Deepseek.""" # TODO(shuningjin): support hf->orbax(scan), b/457820372 if not saving_to_hf: raise NotImplementedError("This conversion only supports saving_to_hf") @@ -944,18 +929,19 @@ def DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN(): def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): - """Returns mapping from MaxText gpt-oss to Hugging Face weight paths. + """Generates mapping from MaxText gpt-oss to Hugging Face weight paths. - Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval) + Returns: + dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter) or + `composite_mt_key` (a tuple of MaxText parameters). Values are Hugging Face parameter + names either a single string (unscanned form) or a list of strings (scanned form). 
- Handles N-to-1 mapping from maxtext to huggingface - - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj - - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias + Notes: + - Handles the inhomogeneous scan block structure, based on `inhomogeneous_layer_cycle_interval` + - Handles `composite_mt_key`: multiple MaxText keys map to HF key(s) + - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj + - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias """ - # TODO(shuningjin): add unscan support, b/459541579 - if not scan_layers: - raise NotImplementedError("Current gpt-oss mapping only supports scan_layers=True") - n_layers = config["num_hidden_layers"] # hf config layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval @@ -966,63 +952,82 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals "params-decoder-logits_dense-kernel": "lm_head.weight", } - for block_idx in range(layer_cycle_interval): - # Identify all original HF layer indices that collapse into this block - hf_indices = list(range(block_idx, n_layers, layer_cycle_interval)) - prefix = f"params-decoder-layers-layers_{block_idx}" - - # Layer Norms - mapping[f"{prefix}-pre_self_attention_layer_norm-scale"] = [ - f"model.layers.{i}.input_layernorm.weight" for i in hf_indices - ] - mapping[f"{prefix}-post_self_attention_layer_norm-scale"] = [ - f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices - ] - - # GptOssAttention - mapping.update( - { - f"{prefix}-GptOssAttention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices], - f"{prefix}-GptOssAttention-query-bias": [f"model.layers.{i}.self_attn.q_proj.bias" for i in hf_indices], - f"{prefix}-GptOssAttention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices], - f"{prefix}-GptOssAttention-key-bias": [f"model.layers.{i}.self_attn.k_proj.bias" for i in hf_indices], - f"{prefix}-GptOssAttention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices], - f"{prefix}-GptOssAttention-value-bias": [f"model.layers.{i}.self_attn.v_proj.bias" for i in hf_indices], - f"{prefix}-GptOssAttention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices], - f"{prefix}-GptOssAttention-out-bias": [f"model.layers.{i}.self_attn.o_proj.bias" for i in hf_indices], - f"{prefix}-GptOssAttention-sinks": [f"model.layers.{i}.self_attn.sinks" for i in hf_indices], - } - ) - - # GptOssMlp - # 1. Gate/Router - mapping.update( - { - f"{prefix}-GptOssMlp-gate-kernel": [f"model.layers.{i}.mlp.router.weight" for i in hf_indices], - f"{prefix}-GptOssMlp-gate-bias": [f"model.layers.{i}.mlp.router.bias" for i in hf_indices], - } - ) - - # 2. 
Experts (Down Projection) - mapping.update( - { - f"{prefix}-GptOssMlp-wo": [f"model.layers.{i}.mlp.experts.down_proj" for i in hf_indices], - f"{prefix}-GptOssMlp-wo_bias": [f"model.layers.{i}.mlp.experts.down_proj_bias" for i in hf_indices], - } - ) + if scan_layers: + # Scan over blocks + for block_idx in range(layer_cycle_interval): + # Identify all original HF layer indices that collapse into this block + hf_indices = range(block_idx, n_layers, layer_cycle_interval) + prefix = f"params-decoder-layers-layers_{block_idx}" + block_mapping = { + # Layer Norms + f"{prefix}-pre_self_attention_layer_norm-scale": [ + f"model.layers.{i}.input_layernorm.weight" for i in hf_indices + ], + f"{prefix}-post_self_attention_layer_norm-scale": [ + f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices + ], + # GptOssAttention + f"{prefix}-GptOssAttention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices], + f"{prefix}-GptOssAttention-query-bias": [f"model.layers.{i}.self_attn.q_proj.bias" for i in hf_indices], + f"{prefix}-GptOssAttention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices], + f"{prefix}-GptOssAttention-key-bias": [f"model.layers.{i}.self_attn.k_proj.bias" for i in hf_indices], + f"{prefix}-GptOssAttention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices], + f"{prefix}-GptOssAttention-value-bias": [f"model.layers.{i}.self_attn.v_proj.bias" for i in hf_indices], + f"{prefix}-GptOssAttention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices], + f"{prefix}-GptOssAttention-out-bias": [f"model.layers.{i}.self_attn.o_proj.bias" for i in hf_indices], + f"{prefix}-GptOssAttention-sinks": [f"model.layers.{i}.self_attn.sinks" for i in hf_indices], + # GptOssMlp + # 1. Gate/Router + f"{prefix}-GptOssMlp-gate-kernel": [f"model.layers.{i}.mlp.router.weight" for i in hf_indices], + f"{prefix}-GptOssMlp-gate-bias": [f"model.layers.{i}.mlp.router.bias" for i in hf_indices], + # 2. Experts (Down Projection) + f"{prefix}-GptOssMlp-wo": [f"model.layers.{i}.mlp.experts.down_proj" for i in hf_indices], + f"{prefix}-GptOssMlp-wo_bias": [f"model.layers.{i}.mlp.experts.down_proj_bias" for i in hf_indices], + # 3. Experts (Gate/Up Fused Projection) + # `composite_mt_key`: Multiple MaxText keys map to HF key(s). + (f"{prefix}-GptOssMlp-wi_0", f"{prefix}-GptOssMlp-wi_1"): [ + f"model.layers.{i}.mlp.experts.gate_up_proj" for i in hf_indices + ], + (f"{prefix}-GptOssMlp-wi_0_bias", f"{prefix}-GptOssMlp-wi_1_bias"): [ + f"model.layers.{i}.mlp.experts.gate_up_proj_bias" for i in hf_indices + ], + } + mapping.update(block_mapping) - # 3. 
Experts (Gate/Up Fused Projection) - # N-to-1 mapping - mapping.update( - { - (f"{prefix}-GptOssMlp-wi_0", f"{prefix}-GptOssMlp-wi_1"): [ - f"model.layers.{i}.mlp.experts.gate_up_proj" for i in hf_indices - ], - (f"{prefix}-GptOssMlp-wi_0_bias", f"{prefix}-GptOssMlp-wi_1_bias"): [ - f"model.layers.{i}.mlp.experts.gate_up_proj_bias" for i in hf_indices - ], - } - ) + else: + # Unscan + for i in range(n_layers): + prefix = f"params-decoder-layers_{i}" + layer_mapping = { + # Layer Norms + f"{prefix}-pre_self_attention_layer_norm-scale": f"model.layers.{i}.input_layernorm.weight", + f"{prefix}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight", + # GptOssAttention + f"{prefix}-GptOssAttention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight", + f"{prefix}-GptOssAttention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias", + f"{prefix}-GptOssAttention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight", + f"{prefix}-GptOssAttention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias", + f"{prefix}-GptOssAttention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight", + f"{prefix}-GptOssAttention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias", + f"{prefix}-GptOssAttention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight", + f"{prefix}-GptOssAttention-out-bias": f"model.layers.{i}.self_attn.o_proj.bias", + f"{prefix}-GptOssAttention-sinks": f"model.layers.{i}.self_attn.sinks", + # GptOssMlp + # 1. Gate/Router + f"{prefix}-GptOssMlp-gate-kernel": f"model.layers.{i}.mlp.router.weight", + f"{prefix}-GptOssMlp-gate-bias": f"model.layers.{i}.mlp.router.bias", + # 2. Experts (Down Projection) + f"{prefix}-GptOssMlp-wo": f"model.layers.{i}.mlp.experts.down_proj", + f"{prefix}-GptOssMlp-wo_bias": f"model.layers.{i}.mlp.experts.down_proj_bias", + # 3. Experts (Gate/Up Fused Projection) + # `composite_mt_key`: Multiple MaxText keys map to HF key(s). + (f"{prefix}-GptOssMlp-wi_0", f"{prefix}-GptOssMlp-wi_1"): f"model.layers.{i}.mlp.experts.gate_up_proj", + ( + f"{prefix}-GptOssMlp-wi_0_bias", + f"{prefix}-GptOssMlp-wi_1_bias", + ): f"model.layers.{i}.mlp.experts.gate_up_proj_bias", + } + mapping.update(layer_mapping) return mapping @@ -1030,24 +1035,16 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals def GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): """Transformation hooks for gpt-oss parameters. 
- Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval) - - Handles N-to-1 mapping from maxtext to huggingface - - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj - - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias + Notes: + - Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval) + - Handles `composite_mt_key` where multiple MaxText keys map to HF key(s) + - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj + - (GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias): mlp.experts.gate_up_proj_bias + - The composite keys are transformed via `interleave` function """ - # TODO(shuningjin): support hf->orbax(scan), b/459541579 - if not saving_to_hf: - raise NotImplementedError("Currently gpt-oss only supports saving_to_hf=True.") - # TODO(shuningjin): add unscan support, b/459541579 - if not scan_layers: - raise NotImplementedError("Currently gpt-oss only supports scan_layers=True.") def transpose(input_tensor, target_shape=None): - if saving_to_hf: - return input_tensor.T - else: - return input_tensor.T + return input_tensor.T def reshape_kernel(input_tensor, target_shape): """Reshapes and transposes kernel weights between MaxText and HF.""" @@ -1068,39 +1065,42 @@ def reshape_bias(input_tensor, target_shape=None): def interleave(input_tensor, target_shape=None): """ - N-to-1 mapping: maxtext (wi_0, wi_1) <-> hf (wi_0_1) - if saving_to_hf, input_tensor is a list of tensors + Handles `composite_mt_key`: maxtext (wi_0, wi_1) <-> hf (wi_0_1) + - if saving_to_hf: (wi_0, wi_1) -> wi_0_1 + - input_tensor is a list of two tensors, tensor ORDER must be same as key order + - return a single tensor + - otherwise: wi_0_1 -> (wi_0, wi_1) + - input_tensor is a single tensor + - return two tensors stack at LAST index -1, tensor ORDER must be same as key order """ if saving_to_hf: - # (wi_0, wi_1) -> wi_0_1 wi_0, wi_1 = input_tensor wi_0_1 = np.empty(target_shape, dtype=wi_0.dtype) wi_0_1[..., ::2] = wi_0 wi_0_1[..., 1::2] = wi_1 return wi_0_1 else: - # wi_0_1 -> (wi_0, wi_1) - # TODO(shuningjin): support hf->orbax(scan), b/459541579 - raise NotImplementedError + wi_0_1 = input_tensor + wi_0 = wi_0_1[..., ::2] + wi_1 = wi_0_1[..., 1::2] + return np.stack([wi_0, wi_1], axis=-1) - hooks = { - "params-decoder-logits_dense-kernel": transpose, - } - - # Scan over blocks + n_layers = config["num_hidden_layers"] # hf config layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval - for block_idx in range(layer_cycle_interval): - prefix = f"params-decoder-layers-layers_{block_idx}" + + hooks = {"params-decoder-logits_dense-kernel": transpose} + + indices = range(layer_cycle_interval) if scan_layers else range(n_layers) + for idx in indices: + prefix = f"params-decoder-layers-layers_{idx}" if scan_layers else f"params-decoder-layers_{idx}" # Attention Kernels & Biases for key in ["query", "key", "value"]: hooks[f"{prefix}-GptOssAttention-{key}-kernel"] = reshape_kernel hooks[f"{prefix}-GptOssAttention-{key}-bias"] = reshape_bias - hooks[f"{prefix}-GptOssAttention-out-kernel"] = reshape_kernel - # MLP Kernels & Biases hooks[f"{prefix}-GptOssMlp-gate-kernel"] = transpose - # Experts (Gate/Up Fused Projection), N-to-1 mapping + # `composite_mt_key`: A hook for combining multiple MaxText params. 
hooks[(f"{prefix}-GptOssMlp-wi_0", f"{prefix}-GptOssMlp-wi_1")] = interleave hooks[(f"{prefix}-GptOssMlp-wi_0_bias", f"{prefix}-GptOssMlp-wi_1_bias")] = interleave @@ -1207,8 +1207,9 @@ def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals into a single param. Defaults to False. Returns: - dict: A mapping from MaxText parameter names to HF parameter names (str) - or lists of names (if scan_layers=True). + dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names). + Values are either a single string (unscanned form) or a list of strings + (scanned form) for stacked layers when `scan_layers=True`. """ n_layers = config["num_hidden_layers"] @@ -1284,63 +1285,46 @@ def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=Fals nlayers = config["num_hidden_layers"] def scale_query_layer(input_tensor, target_shape): - def to_hf(): + if saving_to_hf: depth_scale = np.dtype("float32").type(np.sqrt(config["head_dim"])) - original_dtype = input_tensor.dtype output_tensor = input_tensor.astype(np.float32) * depth_scale return output_tensor.astype(original_dtype) - - def from_hf(): + else: depth_scale = np.dtype("float32").type(1 / np.sqrt(config["head_dim"])) - original_dtype = input_tensor.dtype output_tensor = input_tensor.astype(np.float32) * depth_scale return output_tensor.astype(original_dtype) + def adjust_rope(input_tensor, target_shape): + arr = input_tensor if saving_to_hf: - return to_hf() + # Convert from MaxText's interleaved layout to HF's concatenated layout + evens = arr[..., ::2] + odds = arr[..., 1::2] + return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) else: - return from_hf() - - def adjust_rope(input_tensor, target_shape): - def from_hf(arr): - """Convert from HF's concatenated layout to MaxText's interleaved layout""" + # Convert from HF's concatenated layout to MaxText's interleaved layout half_dim = arr.shape[-1] // 2 first_half = arr[..., :half_dim] second_half = arr[..., half_dim:] return jax.numpy.stack([first_half, second_half], axis=-1).reshape(arr.shape) - def to_hf(arr): - """Convert from MaxText's interleaved layout to HF's concatenated layout""" - evens = arr[..., ::2] - odds = arr[..., 1::2] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) - - if saving_to_hf: - return to_hf(input_tensor) - else: - return from_hf(input_tensor) - def reshape_kernel(input_tensor, target_shape): - def to_hf(): + if saving_to_hf: flipped_target_shape = np.flip(np.array(target_shape)) return input_tensor.reshape(flipped_target_shape).transpose() - - def from_hf(): - return input_tensor.transpose().reshape(target_shape) - - if saving_to_hf: - return to_hf() else: - return from_hf() - - query_hooks = [reshape_kernel, adjust_rope, scale_query_layer] - key_hooks = [reshape_kernel, adjust_rope] + return input_tensor.transpose().reshape(target_shape) + # caveat: hook order does affect result + # to_huggingface + query_hook_chain = [scale_query_layer, adjust_rope, reshape_kernel] + key_hook_chain = [adjust_rope, reshape_kernel] + # to_maxtext if not saving_to_hf: - query_hooks.reverse() - key_hooks.reverse() + query_hook_chain.reverse() + key_hook_chain.reverse() hook_fns = {} @@ -1349,8 +1333,8 @@ def from_hf(): if scan_layers: hook_fns = { **hook_fns, - "params-decoder-layers-self_attention-query-kernel": query_hooks, - "params-decoder-layers-self_attention-key-kernel": key_hooks, + "params-decoder-layers-self_attention-query-kernel": query_hook_chain, + 
"params-decoder-layers-self_attention-key-kernel": key_hook_chain, "params-decoder-layers-self_attention-value-kernel": reshape_kernel, "params-decoder-layers-self_attention-out-kernel": reshape_kernel, "params-decoder-layers-mlp-wi_0-kernel": reshape_kernel, @@ -1359,8 +1343,8 @@ def from_hf(): } else: for layer_idx in range(nlayers): - hook_fns[f"params-decoder-layers_{layer_idx}-self_attention-query-kernel"] = query_hooks - hook_fns[f"params-decoder-layers_{layer_idx}-self_attention-key-kernel"] = key_hooks + hook_fns[f"params-decoder-layers_{layer_idx}-self_attention-query-kernel"] = query_hook_chain + hook_fns[f"params-decoder-layers_{layer_idx}-self_attention-key-kernel"] = key_hook_chain hook_fns[f"params-decoder-layers_{layer_idx}-self_attention-value-kernel"] = reshape_kernel hook_fns[f"params-decoder-layers_{layer_idx}-self_attention-out-kernel"] = reshape_kernel hook_fns[f"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel"] = reshape_kernel @@ -1426,7 +1410,13 @@ def transform_query_kernel(arr): def MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): """ - Returns the mapping of parameter names from MaxText to Hugging Face for Mixtral. + Generates the mapping of parameter names from MaxText to Hugging Face for Mixtral. + + Returns: + dict: A mapping where keys are `atomic_mt_key` (single MaxText parameter names). Values + are Hugging Face parameter names in one of four forms: unscanned string, + scanned list of strings, unscanned with expert stacking (list of strings), + or scanned with expert stacking (nested list of strings). """ mapping = {} @@ -1543,9 +1533,12 @@ def scale_query_layer(input_tensor, target_shape): depth_scale = np.dtype("float32").type(1 / np.sqrt(maxtext_config.head_dim)) return (input_tensor * depth_scale).astype(input_tensor.dtype) + # hook order does not affect result + query_hook_chain = [reshape_and_transpose_attention, scale_query_layer] + if scan_layers: plan = [ - ("params-decoder-layers-self_attention-query-kernel", [reshape_and_transpose_attention, scale_query_layer]), + ("params-decoder-layers-self_attention-query-kernel", query_hook_chain), ("params-decoder-layers-self_attention-key-kernel", reshape_and_transpose_attention), ("params-decoder-layers-self_attention-value-kernel", reshape_and_transpose_attention), ("params-decoder-layers-self_attention-out-kernel", reshape_and_transpose_attention), @@ -1556,7 +1549,7 @@ def scale_query_layer(input_tensor, target_shape): ] else: plan = [ - ("params-decoder-layers_{i}-self_attention-query-kernel", [reshape_and_transpose_attention, scale_query_layer]), + ("params-decoder-layers_{i}-self_attention-query-kernel", query_hook_chain), ("params-decoder-layers_{i}-self_attention-key-kernel", reshape_and_transpose_attention), ("params-decoder-layers_{i}-self_attention-value-kernel", reshape_and_transpose_attention), ("params-decoder-layers_{i}-self_attention-out-kernel", reshape_and_transpose_attention), diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py index 24eea328d9..cd51b6aec1 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/utils.py +++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py @@ -93,18 +93,114 @@ def _get_local_directory(output_dir: str) -> str: return local_dir +def validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys): + """Validates param_mapping coverage and filters unused keys, for to_maxtext and to_huggingface. + + Preprocess maxtext keys for transformation. 
+  - Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by
+    the flattened param_mapping.
+  - Keys in the param_mapping that are not present in the checkpoint (common for
+    multi-variant maps like gemma3, qwen3, deepseek) are skipped.
+
+  Args:
+    param_map_keys: MaxText keys from the `PARAM_MAPPING`. These can be:
+      - `atomic_mt_key`: A single string representing one MaxText parameter that maps to HF parameter(s).
+      - `composite_mt_key`: A tuple of strings representing multiple MaxText parameters that map to HF parameter(s).
+    maxtext_state_keys: Set of MaxText keys loaded from the Orbax checkpoint.
+
+  Returns:
+    A list of 'filtered' mapping keys (strings or tuples) that are fully present
+    and valid based on `maxtext_state_keys`.
+
+  Raises:
+    ValueError: If `maxtext_state_keys` is NOT a subset of the flattened
+      `param_map_keys`.
+  """
+  flattened_map_keys = set()
+  for key in param_map_keys:
+    if isinstance(key, tuple):
+      flattened_map_keys.update(key)
+    else:
+      flattened_map_keys.add(key)
+
+  # 1 Validate: every maxtext state key must be covered by param map
+  missing_keys = maxtext_state_keys - flattened_map_keys
+  if missing_keys:
+    raise ValueError(
+        "maxtext_state_dict must be a subset of flattened param_map"
+        + f"\nparam map\n{param_map_keys}"
+        + f"\nmaxtext:\n{maxtext_state_keys}"
+    )
+
+  # 2 Filter: param map may have extra keys
+  extra_keys = flattened_map_keys - maxtext_state_keys
+  if extra_keys:
+    max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}")
+
+  # skip extra keys in param map
+  filtered_map_keys = []
+  for key in param_map_keys:
+    if (isinstance(key, str) and key in maxtext_state_keys) or (
+        isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)
+    ):
+      filtered_map_keys.append(key)
+  return filtered_map_keys
+
+
+def apply_hook_fns(weight, target_shape, hook_fns):
+  """Applies hook functions; essential for to_maxtext and to_huggingface."""
+  # If hook is unspecified, use identity
+  if hook_fns is None:
+    return weight
+  if not isinstance(hook_fns, list):
+    hook_fns = [hook_fns]
+  # Apply a list of hooks; be careful of the order
+  for hook_fn in hook_fns:
+    weight = hook_fn(weight, target_shape)
+  return weight
+
+
+def convert_jax_weight_to_numpy(weight: "jax.Array", dtype_str: None | str = None) -> np.ndarray:
+  """Converts a JAX array to a NumPy array with the specified dtype; used in to_huggingface.
+
+  Args:
+    weight: The input JAX array, potentially sharded across devices.
+    dtype_str: The target NumPy dtype as a string (e.g., 'float32', 'bfloat16').
+      If None, the dtype of the input JAX array is preserved. Defaults to None.
+
+  Returns:
+    A NumPy array containing the data from `weight`, cast to `dtype_str` if provided.
+  """
+  final_dtype_str = str(weight.dtype) if dtype_str is None else dtype_str
+  # JAX dtypes like 'bfloat16', 'float32' are understood by np.dtype()
+  target_np_dtype = np.dtype(final_dtype_str)
+  expected_shape = weight.shape
+
+  # Gather the array across devices if it's sharded.
+  # process_allgather typically returns the array on the host.
+  weight = multihost_utils.process_allgather(weight)
+
+  # Convert JAX array to NumPy array.
+  np_array = np.array(weight)
+
+  # Cast to the target NumPy dtype if it's different.
+  if np_array.dtype != target_np_dtype:
+    np_array = np_array.astype(target_np_dtype)
+
+  return np_array.reshape(expected_shape)  # Reshape for safety, though usually preserved.
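A minimal usage sketch of `validate_and_filter_param_map_keys` and `apply_hook_fns` above; the toy keys, shapes, and hooks are hypothetical and only illustrate the atomic vs. composite key handling and the left-to-right hook order:

import numpy as np

from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, validate_and_filter_param_map_keys

# One atomic key, one composite (tuple) key, and one hypothetical key absent from the checkpoint.
param_map_keys = [
    "params-decoder-logits_dense-kernel",
    ("params-decoder-layers-GptOssMlp-wi_0", "params-decoder-layers-GptOssMlp-wi_1"),
    "params-decoder-hypothetical-extra-kernel",
]
maxtext_state_keys = {
    "params-decoder-logits_dense-kernel",
    "params-decoder-layers-GptOssMlp-wi_0",
    "params-decoder-layers-GptOssMlp-wi_1",
}
# Keeps the atomic key and the complete tuple; warns about and skips the extra key.
filtered = validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys)
assert len(filtered) == 2

# Hooks run left to right, so non-commuting transforms must be listed in the intended order.
scale = lambda w, shape: w * 2.0
transpose = lambda w, shape: w.T
out = apply_hook_fns(np.ones((2, 3), dtype=np.float32), target_shape=(3, 2), hook_fns=[scale, transpose])
assert out.shape == (3, 2)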
+
+
 def _process(hf_path, processed_slice, output_weights, current_hook_fns, hf_shape_map):
-  """Applies hooks, converts a JAX slice to NumPy, and appends it to the output list."""
+  """Applies hooks, converts a JAX slice to NumPy, and appends it to the output list; used in to_huggingface."""
   if hf_path not in hf_shape_map:
     raise ValueError(f"HF path '{hf_path}' not found in hf_shape_map.")
   target_hf_shape = hf_shape_map[hf_path]
+  # If hook is unspecified, use identity
   if current_hook_fns:
-    # otherwise identity
     processed_slice = apply_hook_fns(processed_slice, target_hf_shape, current_hook_fns)
   numpy_slice = convert_jax_weight_to_numpy(processed_slice).squeeze()
-  assert len(target_hf_shape) == len(
-      numpy_slice.shape
-  ), f"shape mismatch {len(target_hf_shape)} and {len(numpy_slice.shape)}"
+  if numpy_slice.shape != target_hf_shape:
+    raise ValueError(f"Shape mismatch for {hf_path}: Expected {target_hf_shape}, got {numpy_slice.shape}")
   output_weights.append((hf_path, numpy_slice))
@@ -116,22 +212,23 @@ def process_maxtext_param(
     hf_shape_map: dict[str, Any],
     maxtext_config: Any,
 ) -> list[tuple[str, np.ndarray]]:
-  """Processes a single MaxText parameter (or a group of parameters) for conversion to Hugging Face format.
+  """Processes a single MaxText parameter (or a group of parameters) for conversion; used in to_huggingface.

-  This function is responsible for taking a MaxText parameter (which might be
-  a single tensor or a list of tensors for N-to-1 mappings) and transforming
+  This function is responsible for taking a MaxText parameter and transforming
   it into one or more Hugging Face compatible parameters. It handles various
-  scenarios including:
-  - 1-to-1 mappings (single MaxText param to single HF param).
-  - N-to-1 mappings (multiple MaxText params combined into a single HF param).
-  - Stacked MaxText parameters (e.g., scanned layers or MoE experts) that need
-    to be unstacked into individual Hugging Face parameters.
+  scenarios based on:
+  - the MaxText key form (`atomic_mt_key` or `composite_mt_key`)
+  - the Hugging Face value form (unscanned string, scanned list of strings,
+    unscanned with expert stacking, or scanned with expert stacking).
+  Note: We assume a `composite_mt_key` can only occur for unscanned/scanned HF keys, but not for those with expert stacking.

   Args:
-    maxtext_param_key: The key (or tuple of keys for N-to-1 mappings) identifying
-      the MaxText parameter(s) being processed.
+    maxtext_param_key: The key identifying the MaxText parameter(s). Can be
+      an `atomic_mt_key` (str) or a `composite_mt_key` (tuple of str) mapping
+      to HF parameter(s).
     maxtext_param_weight: The actual weight(s) of the MaxText parameter(s).
-      This can be a single `jax.Array` or a list of `jax.Array` for N-to-1 mappings.
+      This can be a single `jax.Array` for an `atomic_mt_key` or a list of
+      `jax.Array` for a `composite_mt_key`.
     param_map: A dictionary mapping MaxText parameter keys to their corresponding
       Hugging Face target path(s).
     hook_fn_map: A dictionary mapping MaxText parameter keys to transformation
@@ -160,7 +257,7 @@ def process_maxtext_param(
   # This list will store tuples of (hf_path, hf_weight)
   output_weights = []

-  # Case 1: Unscan
+  # Case 1: Unscanned
   if not isinstance(hf_target_paths, list):
     max_logging.log("\tunscan")
     hf_path = hf_target_paths
@@ -169,98 +266,62 @@
   # Stacked MaxText weight
   # This now handles three cases:
-  # 2. Scanned MoE layers (2D list of targets from a tensor stacked on expert and layer axes)
+  # 2.
Standard scanned layers (1D list of targets from a tensor stacked only on the layer axis) # 3. Unscanned MoE layers (1D list of targets from a tensor stacked only on the expert axis) - # 4. Standard scanned layers (1D list of targets from a tensor stacked only on the layer axis) - is_scanned_moe_layer = isinstance(hf_target_paths[0], list) - - if is_scanned_moe_layer: - max_logging.log("\tscan moe") - # Case 2: Scanned MoE layer, e.g., from 'layers-moe_block-wi_0'. - # The tensor is stacked on expert and layer axes. We slice experts first, then layers. - # MaxText format is (experts, layers, ...), so expert axis is 0, layer axis is 1. - expert_axis_to_slice = 0 - - # Outer loop for experts - for expert_idx, expert_paths_for_layer in enumerate(hf_target_paths): - # Slice along the expert axis to get the tensor for the current expert across all layers. - expert_tensor_slice = jax.lax.index_in_dim( - maxtext_param_weight, expert_idx, axis=expert_axis_to_slice, keepdims=False - ) - # Inner loop for layers - for layer_idx, hf_path in enumerate(expert_paths_for_layer): - # Slice the expert tensor along the layer axis to get the final individual weight. - # axis is 0 on the new sliced tensor - layer_tensor_slice = jax.lax.index_in_dim(expert_tensor_slice, layer_idx, axis=0, keepdims=False) - _process(hf_path, layer_tensor_slice, output_weights, current_hook_fns, hf_shape_map) + # 4. Scanned MoE layers (2D list of targets from a tensor stacked on expert and layer axes) + + if not isinstance(hf_target_paths[0], list): + # Case 2 or 3: The source tensor is stacked on a single axis. + # i.e., hf_target_paths is an (un-nested) list + # We determine if it's standard scanned (stack on layer axis) or unscanned MoE (stack on expert axis). + if maxtext_config.scan_layers: + max_logging.log("\tscan") + # Case 2: Standard scanned layer. + # The tensor is stacked ONLY on the layer axis. + axis_to_slice = maxtext_config.param_scan_axis + else: + max_logging.log("\tunscan moe") + # Case 3: Unscanned MoE layer, e.g., from 'layers_0-moe_block-wi_0'. + # The tensor is stacked ONLY on the expert axis. Assuming expert is axis 0. + axis_to_slice = 0 + + # Iterate through the slices of the MaxText weight along the determined stacking axis. + # Handles MaxText key forms (`atomic_mt_key` and `composite_mt_key`) + for i, hf_path in enumerate(hf_target_paths): + if isinstance(maxtext_param_weight, list): + # This handles `composite_mt_key` mappings where `maxtext_param_weight` is a list of tensors. + # Each tensor in the list is sliced independently along the `axis_to_slice`. + weight_slice = [jax.lax.index_in_dim(x, i, axis=axis_to_slice, keepdims=False) for x in maxtext_param_weight] + else: + # For `atomic_mt_key` mappings, slice the single MaxText tensor. + weight_slice = jax.lax.index_in_dim(maxtext_param_weight, i, axis=axis_to_slice, keepdims=False) + _process(hf_path, weight_slice, output_weights, current_hook_fns, hf_shape_map) return output_weights - # Case 3 or 4: The source tensor is stacked on a single axis. - # We determine if it's an unscanned MoE (expert axis) or standard scanned (layer axis). - # `w` is needed for weights, and except for gate. - # Gate values are stack in layers only, but weights are stack in both expert and layer. 
- moe_block_list = ["moe_block", "MoeBlock_0-w"] - is_unscanned_moe = any(block in maxtext_param_key for block in moe_block_list) and any( - f"_{i}-" in maxtext_param_key for i in range(maxtext_config.base_num_decoder_layers) - ) - - if is_unscanned_moe: - max_logging.log("\tunscan moe") - # Case 3: Unscanned MoE layer, e.g., from 'layers_0-moe_block-wi_0'. - # The tensor is stacked ONLY on the expert axis. Assuming expert is axis 0. - axis_to_slice = 0 - else: - max_logging.log("\tscan") - # Case 4: Standard scanned layer. - # The tensor is stacked ONLY on the layer axis. - axis_to_slice = maxtext_config.param_scan_axis - - # Iterate through the slices of the MaxText weight along the determined stacking axis. - for i, hf_path in enumerate(hf_target_paths): - if isinstance(maxtext_param_weight, list): - # This handles N-to-1 mappings where `maxtext_param_weight` is a list of tensors. - # Each tensor in the list is sliced independently along the `axis_to_slice`. - weight_slice = [jax.lax.index_in_dim(x, i, axis=axis_to_slice, keepdims=False) for x in maxtext_param_weight] - else: - # For 1-to-1 mappings, slice the single MaxText tensor. - weight_slice = jax.lax.index_in_dim(maxtext_param_weight, i, axis=axis_to_slice, keepdims=False) - _process(hf_path, weight_slice, output_weights, current_hook_fns, hf_shape_map) + # Multi axis stacked: isinstance(hf_target_paths[0], list) + max_logging.log("\tscan moe") + # Case 4: Scanned MoE layer, e.g., from 'layers-moe_block-wi_0'. + # The tensor is stacked on expert and layer axes. We slice experts first, then layers. + # MaxText format is (experts, layers, ...), so expert axis is 0, layer axis is 1. + expert_axis_to_slice = 0 + + # Outer loop for experts + for expert_idx, expert_paths_for_layer in enumerate(hf_target_paths): + # Slice along the expert axis to get the tensor for the current expert across all layers. + expert_tensor_slice = jax.lax.index_in_dim( + maxtext_param_weight, expert_idx, axis=expert_axis_to_slice, keepdims=False + ) + # Inner loop for layers + for layer_idx, hf_path in enumerate(expert_paths_for_layer): + # Slice the expert tensor along the layer axis to get the final individual weight. + # axis is 0 on the new sliced tensor + layer_tensor_slice = jax.lax.index_in_dim(expert_tensor_slice, layer_idx, axis=0, keepdims=False) + _process(hf_path, layer_tensor_slice, output_weights, current_hook_fns, hf_shape_map) return output_weights -def convert_jax_weight_to_numpy(weight: "jax.Array", dtype_str: None | str = None) -> np.ndarray: - """Converts a JAX array to a NumPy array with the specified dtype.""" - final_dtype_str = str(weight.dtype) if dtype_str is None else dtype_str - # JAX dtypes like 'bfloat16', 'float32' are understood by np.dtype() - target_np_dtype = np.dtype(final_dtype_str) - expected_shape = weight.shape - - # Gather the array across devices if it's sharded. - # process_allgather typically returns the array on the host. - weight = multihost_utils.process_allgather(weight) - - # Convert JAX array to NumPy array. - np_array = np.array(weight) - - # Cast to the target NumPy dtype if it's different. - if np_array.dtype != target_np_dtype: - np_array = np_array.astype(target_np_dtype) - - return np_array.reshape(expected_shape) # Reshape for safety, though usually preserved. 
- - -def apply_hook_fns(weight, target_shape, hook_fns): - if hook_fns is None: - return weight - if not isinstance(hook_fns, list): - hook_fns = [hook_fns] - for hook_fn in hook_fns[::-1]: - weight = hook_fn(weight, target_shape) - return weight - - def create_huggingface_hub_repo_if_not_exist(repo_id, repo_type): if not repo_exists(repo_id, repo_type=repo_type): api = HfApi() @@ -689,7 +750,7 @@ def print_ram_usage(stage=""): def get_hf_model(model_id: str, token: str): - """Loads the HuggingFace model based on model_id (Eager mode only).""" + """Loads the HuggingFace model based on model_id (Eager mode only), used in to_maxtext""" if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]: from transformers import Qwen3OmniMoeForConditionalGeneration # pylint: disable=import-outside-toplevel diff --git a/tests/forward_pass_logit_checker.py b/tests/forward_pass_logit_checker.py index cc3c529908..3b15e5d1f1 100644 --- a/tests/forward_pass_logit_checker.py +++ b/tests/forward_pass_logit_checker.py @@ -381,10 +381,21 @@ def main(config, test_args): # pylint: disable=W0621 """Comparing maxtext model with HF model on-the-fly""" if test_args.hf_model_path == "": raise ValueError("run_hf_model requires hf_model_path") - hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, dtype=torch.bfloat16) - tokenizer = AutoTokenizer.from_pretrained(test_args.hf_model_path) - pad_token_models = ["Llama-3.1", "Mixtral-8x"] - if any(model in test_args.hf_model_path for model in pad_token_models): + + hf_token = config.hf_access_token + hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, dtype=torch.bfloat16, token=hf_token) + + if os.path.isdir(test_args.hf_model_path): + # local hf directory may not contain tokenizer, read from remote tokenizer + tokenizer_path = config.tokenizer_path + else: + tokenizer_path = test_args.hf_model_path + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=hf_token) + + # maxtext model prefix, use eos token as pad token + pad_token_prefixes = ["llama3.1", "mixtral"] + if any(config.model_name.startswith(prefix) for prefix in pad_token_prefixes): tokenizer.pad_token = tokenizer.eos_token init_rng = jax.random.PRNGKey(config.init_weights_seed)
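For reference, a standalone sketch of the tokenizer fallback and pad-token handling added to the logit checker above; the local directory, fallback repo id, and model name below are hypothetical stand-ins for `test_args.hf_model_path`, `config.tokenizer_path`, and `config.model_name`:

import os

from transformers import AutoTokenizer

hf_model_path = "/tmp/llama31_8b_bf16"  # local weights dir, may not ship tokenizer files
fallback_tokenizer_path = "meta-llama/Llama-3.1-8B"  # remote tokenizer used when the dir is local
model_name = "llama3.1-8b"

# A local HF directory may only hold weights, so read the tokenizer from the remote path instead.
tokenizer_path = fallback_tokenizer_path if os.path.isdir(hf_model_path) else hf_model_path
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# llama3.1 and mixtral tokenizers define no pad token, so reuse the EOS token for padding.
if any(model_name.startswith(prefix) for prefix in ("llama3.1", "mixtral")):
  tokenizer.pad_token = tokenizer.eos_token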