From 01f665d7d87956101b728a719db9013058636433 Mon Sep 17 00:00:00 2001 From: Gagik Amirkhanyan Date: Mon, 29 Dec 2025 22:45:34 -0800 Subject: [PATCH] Add soft distillation training script and configuration. --- src/MaxText/configs/distillation.yml | 61 ++ src/MaxText/configs/types.py | 19 + src/MaxText/distillation/__init__.py | 13 + src/MaxText/distillation/train_distill.py | 581 ++++++++++++++++++ .../integration/tunix/tunix_adapter.py | 4 +- tests/train_distill_test.py | 207 +++++++ 6 files changed, 883 insertions(+), 2 deletions(-) create mode 100644 src/MaxText/configs/distillation.yml create mode 100644 src/MaxText/distillation/__init__.py create mode 100644 src/MaxText/distillation/train_distill.py create mode 100644 tests/train_distill_test.py diff --git a/src/MaxText/configs/distillation.yml b/src/MaxText/configs/distillation.yml new file mode 100644 index 0000000000..cf024fec29 --- /dev/null +++ b/src/MaxText/configs/distillation.yml @@ -0,0 +1,61 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Soft Distillation Configuration + +# Inherit MaxText defaults +base_config: "base.yml" + +# --- Student Specifics --- +# These are passed as kwargs to the Student config initialization +student_overrides: + model_name: "llama3.1-8b" + +# --- Teacher Specifics --- +# These are passed as kwargs to the Teacher config initialization +teacher_overrides: + model_name: "llama3.1-8b" + +# --- Distillation Loss --- +distill_alpha: 0.5 +distill_temperature: 1.0 + +# --- Dataset & Tokenizer --- +hf_path: "OptimalScale/ClimbMix" +dataset_type: "hf" +tokenizer_path: "meta-llama/Llama-3.1-8B" +tokenizer_type: "huggingface" + +# dataset_path: "gs://max-datasets-rogue" +# tokenizer_path: "src/MaxText/assets//tokenizer_llama3.tiktoken" +# tokenizer_type: "tiktoken" + +max_target_length: 2048 + +# --- Training Loop --- +steps: 200000 +checkpoint_period: 2000 +log_period: 10 +save_checkpoint_on_completion: True + +# --- Batch Size Strategy --- +# Global Batch Size = per_device_batch_size * num_devices * gradient_accumulation_steps +per_device_batch_size: 2 +gradient_accumulation_steps: 8 + +# --- Learning Rate Schedule --- +learning_rate: 2.0e-4 +learning_rate_schedule_steps: 200000 +warmup_steps_fraction: 0.1 +cosine_learning_rate_final_fraction: 0.1 \ No newline at end of file diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 63f4b6c999..ba0da45538 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -967,6 +967,24 @@ class FineTuning(BaseModel): use_grpo: None | bool = Field(None, description="If True, enables Group Relative Policy Optimization.") +class Distillation(BaseModel): + """Configuration for Knowledge Distillation.""" + + # --- Overrides --- + # These dictionaries allow flexible configuration injection for Student/Teacher + # without needing to duplicate the entire MaxText schema here. 
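+  # Keys must be valid MaxText config fields; they are forwarded as kwargs to pyconfig.initialize()
+  # (see distillation/train_distill.py), e.g. student_overrides: {"model_name": "llama3.1-8b"}.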
+ student_overrides: dict[str, Any] = Field( + default_factory=dict, description="Overrides specific to the Student model (e.g., {'num_query_heads': 16})." + ) + teacher_overrides: dict[str, Any] = Field( + default_factory=dict, description="Overrides specific to the Teacher model (e.g., {'num_query_heads': 64})." + ) + + # --- Loss Params --- + distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.") + distill_temperature: float = Field(1.0, description="Temperature for distillation softening.") + + class TrainingLoop(BaseModel): """Configuration for the main training loop, evaluation, and reproducibility.""" @@ -1634,6 +1652,7 @@ class MaxTextConfig( AdamW, Muon, FineTuning, + Distillation, # Reinforcement Learning RLHardware, VLLM, diff --git a/src/MaxText/distillation/__init__.py b/src/MaxText/distillation/__init__.py new file mode 100644 index 0000000000..2237c9162e --- /dev/null +++ b/src/MaxText/distillation/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/MaxText/distillation/train_distill.py b/src/MaxText/distillation/train_distill.py new file mode 100644 index 0000000000..c90f74ebf9 --- /dev/null +++ b/src/MaxText/distillation/train_distill.py @@ -0,0 +1,581 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distillation Trainer for MaxText + Tunix. + +This script implements the "Post-Pruning Recovery" distillation process: recovering model quality +via soft distillation from a Teacher model. It leverages the Tunix Distillation library +for the training loop and loss calculation, while using MaxText for efficient +TPU model execution and data loading. + +Architecture Overview: +---------------------- +1. **Dual Model Loading**: Uniquely, this script initializes two distinct MaxText models: + - Student: The model being trained (can be pruned/smaller). + - Teacher: The frozen reference model (usually larger or same size). + +2. **Configuration Isolation**: To support different architectures (e.g., a pruned Student + vs. a full Teacher), we use `pyconfig` to generate two separate configuration objects + derived from the same base YAML but applied with different overrides. + +3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose + a standard interface (call signature) that the Tunix `DistillationTrainer` expects. 
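+
+Usage sketch (illustrative only; the exact launcher, bucket paths, and run name below are
+placeholders, not part of this change):
+
+    TEACHER_CHECKPOINT_PATH=gs://<bucket>/teacher_ckpt/items \
+    python3 -m MaxText.distillation.train_distill src/MaxText/configs/distillation.yml \
+        run_name=<run_name> load_parameters_path=gs://<bucket>/student_ckpt/items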
+""" + +import os +from typing import Any, Iterator, Sequence, Dict, Tuple + +from absl import app +import flax +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +import numpy as np +import optax +from orbax import checkpoint + +# MaxText Imports +from MaxText import max_logging +from MaxText import maxtext_utils +from MaxText import model_creation_utils +from MaxText import optimizers +from MaxText import pyconfig +from MaxText import tokenizer +from MaxText import train_utils +from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter + +# Tunix Imports +from tunix.distillation import distillation_trainer +from tunix.distillation.strategies import logit +from tunix.sft import metrics_logger +from tunix.sft import profiler + + +# ----------------------------------------------------------------------------- +# Distillation Optimizer with cosine decay and warmup +# ----------------------------------------------------------------------------- + + +def get_distillation_optimizer(config, max_train_steps): + """Creates a custom optimizer for distillation that enables Learning Rate logging. + + This function constructs an optax optimizer using standard MaxText settings but + wraps it with `optax.inject_hyperparams`. This wrapper is strictly required + by the Tunix `PeftTrainer` to log the learning rate to TensorBoard; without it, + the trainer cannot find the LR in the optimizer state. + + Args: + config: The HyperParameters object containing optimizer settings (e.g., + `learning_rate`, `adam_b1`, `opt_type`, `gradient_clipping_threshold`). + max_train_steps: The total number of training steps, used to calculate + the warmup and cosine decay schedule. + + Returns: + An optax optimizer that: + 1. Uses the optimizer type specified in `config.opt_type` (AdamW, SGD, etc.). + 2. Follows the MaxText cosine decay schedule. + 3. Applies gradient clipping if configured. + 4. Exposes the learning rate as a hyperparameter in the state for logging. + """ + # Check for unsupported Muon optimizer + if config.opt_type == "muon": + raise ValueError("Muon optimizer is not currently supported in distillation mode.") + + # 1. Define Schedule + schedule = optax.schedules.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=config.learning_rate, + warmup_steps=int(config.warmup_steps_fraction * max_train_steps), + decay_steps=max_train_steps, + end_value=config.cosine_learning_rate_final_fraction * config.learning_rate, + ) + + # 2. Define Factory (Required for inject_hyperparams) + def optimizer_factory(learning_rate): + # Reuse MaxText's standard logic to create the base optimizer. + # We pass 'learning_rate' (which is the injected schedule) directly. + opt = optimizers.get_optimizer(config, learning_rate, model=None) + + # Apply Gradient Clipping + if config.gradient_clipping_threshold > 0: + opt = optax.chain( + optax.clip_by_global_norm(max_norm=config.gradient_clipping_threshold), + opt, + ) + return opt + + # 3. 
Create Injectable Optimizer + # This wraps the factory so 'learning_rate' sits at the top level of the state + optimizer = optax.inject_hyperparams(optimizer_factory)(learning_rate=schedule) + + return optimizer + + +# ----------------------------------------------------------------------------- +# Custom Data Structures & Strategies +# ----------------------------------------------------------------------------- + + +@flax.struct.dataclass(frozen=True) +class MaxTextTrainingInput(distillation_trainer.TrainingInput): + """Extended TrainingInput dataclass to carry MaxText-specific fields. + + Attributes: + positions: Position indices for the tokens (for RoPE). + decoder_segment_ids: Segment IDs for packed sequences (0=padding, 1+=examples). + targets: Ground truth target tokens (used for loss calculation and logging). + """ + + positions: Any = None + decoder_segment_ids: Any = None + targets: Any = None + + +class MonitoredLogitStrategy(logit.LogitStrategy): + """Logit Strategy that returns detailed metrics for TensorBoard.""" + + def compute_loss( + self, + student_output: jax.Array, + teacher_output: jax.Array, + labels: jax.Array, + ) -> Tuple[jax.Array, Dict[str, jax.Array]]: + """Computes Loss and Auxiliary Metrics.""" + # Calculate Distillation Loss (KL Divergence) + # Scale logits by temperature T for soft targets + # We use explicit float32 casting for stability in loss calculation + s_logits = student_output.astype(jnp.float32) + t_logits = teacher_output.astype(jnp.float32) + + log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1) + teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1) + + # KL(Teacher || Student) + kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp) + + # Scale gradients by T^2 (Hinton et al.) + soft_loss = jnp.mean(kl_div) * (self.temperature**2) + + # 1. Student Hard Loss (Existing) + ce_loss_student = optax.softmax_cross_entropy(logits=s_logits, labels=labels) + hard_loss = jnp.mean(ce_loss_student) + + # 2. Teacher Hard Loss (For Verification) + ce_loss_teacher = optax.softmax_cross_entropy(logits=t_logits, labels=labels) + teacher_hard_loss = jnp.mean(ce_loss_teacher) + + # 3. Combine losses + total_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss) + + # 4. Return Loss AND Metrics + metrics = { + "distill/soft_loss": soft_loss, + "distill/hard_loss": hard_loss, + "distill/kl_div": jnp.mean(kl_div), + "distill/teacher_loss": teacher_hard_loss, + } + return total_loss, metrics + + def compute_eval_loss( + self, + student_output: jax.Array, + labels: jax.Array, + ) -> Tuple[jax.Array, Dict[str, jax.Array]]: + """Computes Eval Loss and returns empty aux dict (required for consistency).""" + # Parent logic for task loss + # We re-implement simple CE here to ensure float32 casting + s_logits = student_output.astype(jnp.float32) + ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels) + task_loss = jnp.mean(ce_loss) + + # Must return a tuple because _has_aux=True expects it + return task_loss, {} + + +def _log_config_details(config: pyconfig.HyperParameters, label: str) -> None: + """Logs detailed architecture configuration for verification. + + Args: + config: The HyperParameters object to inspect. + label: A string label (e.g., 'Student', 'Teacher') for the log output. 
+ """ + kv_heads = getattr(config, "num_kv_heads", config.num_query_heads) + max_logging.log(f"--- {label} Configuration ---") + max_logging.log(f" Model Name: {config.model_name}") + max_logging.log( + f" Dimensions: {config.num_decoder_layers} Layers, " f"{config.emb_dim} Emb Dim, {config.head_dim} Head Dim" + ) + max_logging.log(f" Attention Heads: {config.num_query_heads} Query, {kv_heads} KV") + max_logging.log(f" Vocab Size: {config.vocab_size}") + max_logging.log(f" Checkpoint: {config.load_parameters_path}") + + +class MaxTextDistillationTrainer(distillation_trainer.DistillationTrainer): + """Custom Trainer to preserve MaxText fields and log Teacher metrics. + + This class overrides `_prepare_inputs` to ensure MaxText-specific fields + (positions, segment_ids) are passed to the model. + """ + + def _prepare_inputs(self, input_data: MaxTextTrainingInput) -> MaxTextTrainingInput: + """Prepares inputs for the student model and runs the teacher model. + + This function generates the "Soft Targets" (logits) from the Teacher model + that the Student will learn to mimic. + + Args: + input_data: The batch of data from the iterator. + + Returns: + A new MaxTextTrainingInput containing the Teacher's outputs (logits). + """ + # 1. Generate inputs dictionary for the Teacher model + inputs = self.gen_model_input_fn(input_data)["inputs"] + + if self._mode == metrics_logger.Mode.EVAL: + teacher_output = None + else: + # 2. Run Teacher to get soft targets (logits) + # The strategy ensures these are stop_gradient-ed + teacher_output = self.strategy.get_teacher_outputs(self.teacher_model, inputs) + + # 3. Return extended object so fields are available for Student training step + # pylint: disable=unexpected-keyword-arg + return MaxTextTrainingInput( + input_tokens=input_data.input_tokens, + input_mask=input_data.input_mask, + teacher_output=teacher_output, + positions=input_data.positions, + decoder_segment_ids=input_data.decoder_segment_ids, + targets=input_data.targets, + ) + + def _post_process_train_step(self, aux: Dict[str, jax.Array]) -> None: + """Extracts auxiliary metrics from the strategy and buffers them for logging.""" + if self._buffered_train_metrics is None: + return + + # 'aux' contains the dictionary we returned from compute_loss: + # {"distill/soft_loss": ..., "distill/hard_loss": ...} + for name, value in aux.items(): + # We accumulate these values. PeftTrainer handles the averaging. + # The structure expected is: dict[metric_name, (list_of_values, aggregation_fn)] + if name not in self._buffered_train_metrics.additional_metrics: + self._buffered_train_metrics.additional_metrics[name] = ([], np.mean) + + self._buffered_train_metrics.additional_metrics[name][0].append(value) + + +# ----------------------------------------------------------------------------- +# Data Loading Adapter +# ----------------------------------------------------------------------------- + + +class MaxTextToTunixIterator: + """Adapts the raw dictionary output of MaxText's data loader to Tunix objects. + + MaxText's `train_utils.create_data_iterator` yields a dictionary. + Tunix expects an object with specific attributes (input_tokens, etc.). + """ + + def __init__(self, maxtext_iterator: Iterator): + """Initializes the adapter. + + Args: + maxtext_iterator: The upstream iterator created by MaxText's input pipeline. 
+ """ + self._iterator = maxtext_iterator + + def __iter__(self): + """Returns self as the iterator.""" + return self + + def __next__(self) -> MaxTextTrainingInput: + """Fetches the next batch and converts it to the Tunix data class. + + Returns: + A MaxTextTrainingInput object containing the batch data. + + Raises: + StopIteration: If the upstream iterator is exhausted. + """ + batch = next(self._iterator) + + # Ensure segmentation exists, default to ones if missing (standard non-packed) + if "inputs_segmentation" in batch: + input_mask = batch["inputs_segmentation"] != 0 + seg_ids = batch["inputs_segmentation"] + else: + # Fallback for non-packed datasets + input_mask = jnp.ones_like(batch["inputs"], dtype=jnp.bool_) + seg_ids = None + + # pylint: disable=unexpected-keyword-arg + return MaxTextTrainingInput( + input_tokens=batch["inputs"], + input_mask=input_mask, + teacher_output=None, + positions=batch["inputs_position"], + decoder_segment_ids=seg_ids, + targets=batch["targets"], + ) + + +# ----------------------------------------------------------------------------- +# Model Loading +# ----------------------------------------------------------------------------- +def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) -> nnx.Module: + """Loads a MaxText model and wraps it in a Tunix adapter. + + Args: + config: The configuration object for this specific model (Student or Teacher). + mesh: The global device mesh for sharding weights. + + Returns: + A TunixMaxTextAdapter instance wrapping the loaded MaxText model. + """ + max_logging.log(f"Initializing model: {config.model_name}...") + model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh) + + with mesh: + tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=True) + return tunix_model + + +# ----------------------------------------------------------------------------- +# Main Training Loop +# ----------------------------------------------------------------------------- + + +def train_distill(student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters) -> None: + """Main distillation training loop. + + Orchestrates the loading of both student and teacher models, configures the + distillation strategy, and executes the training loop via the Tunix Trainer. + + Args: + student_config: Configuration object for the Student model (learnable). + teacher_config: Configuration object for the Teacher model (frozen). + """ + # Validate vocab size match between Student and Teacher + if student_config.vocab_size != teacher_config.vocab_size: + raise ValueError( + f"Vocab size mismatch! Student: {student_config.vocab_size}, Teacher: {teacher_config.vocab_size}. " + "Distillation requires matching vocabularies." + ) + + # 1. Setup Mesh + devices = jax.devices() + devices_array = maxtext_utils.create_device_mesh(student_config, devices) + mesh = jax.sharding.Mesh(devices_array, student_config.mesh_axes) + + # 2. 
Load Models & Tokenizer Info + tok = tokenizer.build_tokenizer( + tokenizer_path=student_config.tokenizer_path, + tokenizer_type=student_config.tokenizer_type, + add_bos=student_config.add_bos, + add_eos=student_config.add_eos, + hf_access_token=student_config.hf_access_token, + dataset_type=student_config.dataset_type, + ) + pad_id = tok.pad_id if tok.pad_id is not None else 0 + + max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") + _log_config_details(student_config, "Student") + student_model = get_maxtext_model(student_config, mesh) + + max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") + _log_config_details(teacher_config, "Teacher") + teacher_model = get_maxtext_model(teacher_config, mesh) + + # 3. Define Distillation Strategy + def labels_fn(targets, **kwargs): + """Converts integer targets to masked one-hot vectors for hard label loss.""" + del kwargs # Unused + one_hot = jax.nn.one_hot(targets, student_config.vocab_size) + mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None] + return one_hot * mask + + def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs): + """Forward pass wrapper for the MaxText models (Student and Teacher).""" + del kwargs # Unused + # Tunix adapter ensures __call__ signature matches this + outputs = model( + input_tokens=input_tokens, + positions=positions, + cache=cache, + attention_mask=attention_mask, + decoder_segment_ids=decoder_segment_ids, # Support sequence packing + ) + return outputs[0] # Return logits only + + # Both Student and Teacher use the same forward logic via the adapter + student_forward_fn = model_forward_fn + teacher_forward_fn = model_forward_fn + + # Use Monitored strategy to enable KL/Soft/Hard Loss logging + strategy = MonitoredLogitStrategy( + student_forward_fn=student_forward_fn, + teacher_forward_fn=teacher_forward_fn, + labels_fn=labels_fn, + temperature=student_config.distill_temperature, + alpha=student_config.distill_alpha, + ) + + # 4. Optimizer & Config + optimizer = get_distillation_optimizer(student_config, student_config.steps) + + checkpointing_options = checkpoint.CheckpointManagerOptions( + save_interval_steps=student_config.checkpoint_period, max_to_keep=student_config.max_num_checkpoints_to_keep + ) + + profiler_options = None + if student_config.profiler == "xplane": + profiler_options = profiler.ProfilerOptions( + log_dir=student_config.tensorboard_dir, + skip_first_n_steps=student_config.skip_first_n_steps_for_profiler, + profiler_steps=student_config.profiler_steps, + set_profile_options=False, + ) + + metrics_logging_options = metrics_logger.MetricsLoggerOptions( + log_dir=student_config.tensorboard_dir, flush_every_n_steps=student_config.log_period + ) + + train_config = distillation_trainer.TrainingConfig( + max_steps=student_config.steps, + eval_every_n_steps=student_config.eval_interval, + metrics_logging_options=metrics_logging_options, + profiler_options=profiler_options, + checkpoint_root_directory=student_config.checkpoint_dir, + checkpointing_options=checkpointing_options, + ) + + # 5. Initialize Trainer + trainer = MaxTextDistillationTrainer( + student_model=student_model, + teacher_model=teacher_model, + strategy=strategy, + optimizer=optimizer, + training_config=train_config, + ) + trainer.is_managed_externally = True + + # Force enable auxiliary metric logging + trainer._has_aux = True # pylint: disable=protected-access + + # 6. 
Configure Input Mapping + # Maps the attributes of MaxTextTrainingInput to the kwargs expected by the models + trainer = trainer.with_gen_model_input_fn( + lambda batch: { + "input_tokens": batch.input_tokens, + "positions": batch.positions, + "attention_mask": batch.input_mask, + "decoder_segment_ids": batch.decoder_segment_ids, + "targets": batch.targets, # Passed to strategy (labels_fn) + "cache": None, + } + ) + + # 7. Data Iterators + # We use MaxText's native create_data_iterator which creates both train and eval iterators + # based on the config parameters (dataset_type, eval_interval, etc.) + max_logging.log("Initializing Data Iterators via MaxText pipeline...") + raw_train_iter, raw_eval_iter = train_utils.create_data_iterator(student_config, mesh) + + train_iter = MaxTextToTunixIterator(raw_train_iter) + + eval_iter = None + if raw_eval_iter is not None: + max_logging.log("Evaluation iterator successfully initialized.") + eval_iter = MaxTextToTunixIterator(raw_eval_iter) + elif student_config.eval_interval > 0: + max_logging.log("Warning: eval_interval > 0 but create_data_iterator returned None for eval_iter.") + + # 8. Train + max_logging.log("Starting Distillation Training...") + with mesh, nn_partitioning.axis_rules(student_config.logical_axis_rules): + # Pass both iterators to the trainer + trainer.train(train_iter, eval_iter) + + # 9. Final Save (Conditional) + if student_config.save_checkpoint_on_completion: + should_save = student_config.steps % student_config.checkpoint_period + + if should_save: + max_logging.log(f"Saving final checkpoint to {student_config.checkpoint_dir}...") + try: + saved = trainer.checkpoint_manager.save( + trainer.train_steps, trainer.model, save_only_lora_params=getattr(trainer, "_lora_enabled", False), force=True + ) + if saved: + # Ensure underlying orbax manager finishes writing + # pylint: disable=protected-access + if trainer.checkpoint_manager._checkpoint_manager is not None: + trainer.checkpoint_manager._checkpoint_manager.wait_until_finished() + # pylint: enable=protected-access + max_logging.log("Final checkpoint saved.") + + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.log(f"Warning: Failed to save final checkpoint: {e}") + + else: + max_logging.log("Waiting for automatic periodic checkpoint to finish...") + trainer.checkpoint_manager.wait_until_finished() + + trainer.close() + max_logging.log("Distillation Complete.") + + +def main(argv: Sequence[str]) -> None: + """Entry point for the script. + + Parses configuration, isolates Student and Teacher overrides, and triggers the + training loop. + + Args: + argv: List of command-line arguments. Expects [script_name, config_file, ...]. + """ + # 1. Parse Global Config to extract Overrides + global_config = pyconfig.initialize(argv) + + # 2. Initialize STUDENT Config + # Order of precedence: YAML < CLI < kwargs (student_overrides). + student_overrides = global_config.student_overrides + student_config = pyconfig.initialize(argv, **student_overrides) + + # 3. Initialize TEACHER Config + # We isolate the Teacher from Student CLI arguments (like pruning params). 
+ teacher_overrides = global_config.teacher_overrides + + # Ensure load_parameters_path is set (check overrides, then env var) + if not teacher_overrides.get("load_parameters_path"): + ckpt_path = os.environ.get("TEACHER_CHECKPOINT_PATH") + if ckpt_path: + teacher_overrides["load_parameters_path"] = ckpt_path + else: + max_logging.log("Warning: No load_parameters_path found for Teacher.") + + # Construct sanitized argv: [script_name, config_file] + # This ensures flags like `num_query_heads=16` passed in CLI don't affect the Teacher. + teacher_argv = [argv[0], argv[1]] + teacher_config = pyconfig.initialize(teacher_argv, **teacher_overrides) + + # 4. Run Training + train_distill(student_config, teacher_config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/MaxText/integration/tunix/tunix_adapter.py b/src/MaxText/integration/tunix/tunix_adapter.py index 6ffceb4047..9bd36b3360 100644 --- a/src/MaxText/integration/tunix/tunix_adapter.py +++ b/src/MaxText/integration/tunix/tunix_adapter.py @@ -57,6 +57,7 @@ def __call__( positions: Array, # [B, L] cache: Optional[Any], # Tunix currently passes None from Trainers attention_mask: Optional[Array], # [B, L, L] or None + decoder_segment_ids: Optional[Array] = None, output_hidden_states: bool = False, # ignored ) -> Tuple[Array, None]: """Forward compatible with Tunix Trainers default loss. @@ -65,8 +66,7 @@ def __call__( logits = self.base( decoder_input_tokens=input_tokens, decoder_positions=positions, - # TODO: @mazumdera - add support for packing - decoder_segment_ids=None, + decoder_segment_ids=decoder_segment_ids, ) return logits, None diff --git a/tests/train_distill_test.py b/tests/train_distill_test.py new file mode 100644 index 0000000000..1fb2ef6562 --- /dev/null +++ b/tests/train_distill_test.py @@ -0,0 +1,207 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Unit tests for the Distillation Trainer.""" + +import unittest +from unittest import mock +import jax +import jax.numpy as jnp +import numpy as np +import optax +from absl.testing import absltest + +# Import the module under test +from MaxText.distillation import train_distill +from MaxText import pyconfig + + +# pylint: disable=protected-access +class TrainDistillTest(unittest.TestCase): + + def test_maxtext_to_tunix_iterator(self): + """Verifies the adapter correctly converts dictionary batches to dataclasses.""" + + # 1. Create a dummy iterator that simulates MaxText data loader + dummy_batch = { + "inputs": np.array([[10, 11]]), + "inputs_position": np.array([[0, 1]]), + "inputs_segmentation": np.array([[1, 1]]), + "targets": np.array([[11, 12]]), + } + + dummy_iter = iter([dummy_batch]) + + # 2. Initialize Adapter + adapter = train_distill.MaxTextToTunixIterator(dummy_iter) + + # 3. Fetch Batch + tunix_input = next(adapter) + + # 4. 
Verify Fields + self.assertIsInstance(tunix_input, train_distill.MaxTextTrainingInput) + np.testing.assert_array_equal(tunix_input.input_tokens, dummy_batch["inputs"]) + np.testing.assert_array_equal(tunix_input.positions, dummy_batch["inputs_position"]) + np.testing.assert_array_equal(tunix_input.decoder_segment_ids, dummy_batch["inputs_segmentation"]) + np.testing.assert_array_equal(tunix_input.targets, dummy_batch["targets"]) + + # Verify constructed mask (segmentation != 0) + expected_mask = dummy_batch["inputs_segmentation"] != 0 + np.testing.assert_array_equal(tunix_input.input_mask, expected_mask) + + def test_maxtext_to_tunix_iterator_packed_fallback(self): + """Verifies fallback behavior when segmentation is missing.""" + dummy_batch = { + "inputs": np.array([[10, 11]]), + "inputs_position": np.array([[0, 1]]), + "targets": np.array([[11, 12]]), + } + dummy_iter = iter([dummy_batch]) + adapter = train_distill.MaxTextToTunixIterator(dummy_iter) + tunix_input = next(adapter) + + self.assertIsNone(tunix_input.decoder_segment_ids) + self.assertTrue(np.all(tunix_input.input_mask)) + + def test_prepare_inputs_logic(self): + """Verifies the filtering and teacher call logic in the custom trainer.""" + # 1. Initialize Trainer (bypass init) + # pylint: disable=no-value-for-parameter + trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) + + # Setup mocks + trainer._mode = "train" + trainer.strategy = mock.Mock() + trainer.teacher_model = mock.Mock() + trainer.model = mock.Mock() + trainer.gen_model_input_fn = lambda x: {"inputs": {"some_key": "some_val"}} + + # 2. Setup Input + # pylint: disable=unexpected-keyword-arg + input_data = train_distill.MaxTextTrainingInput( + input_tokens=jnp.array([[1]]), + input_mask=jnp.array([[True]]), + positions=jnp.array([[0]]), + targets=jnp.array([[1]]), + ) + + # 3. Mock Strategy Output + fake_teacher_logits = jnp.zeros((1, 1, 10)) + trainer.strategy.get_teacher_outputs.return_value = fake_teacher_logits + + # 4. Run + result = trainer._prepare_inputs(input_data) + + # 5. Verify + trainer.strategy.get_teacher_outputs.assert_called_once() + self.assertIsNotNone(result.teacher_output) + self.assertEqual(result.teacher_output.shape, (1, 1, 10)) + + def test_optimizer_factory(self): + """Verifies the optimizer factory injects hyperparams and handles configs.""" + # Mock config + config = mock.Mock(spec=pyconfig.HyperParameters) + config.learning_rate = 1e-3 + config.opt_type = "adamw" + config.adam_b1 = 0.9 + config.adam_b2 = 0.99 + config.adam_eps = 1e-8 + config.adam_eps_root = 0.0 + config.adam_weight_decay = 0.0 + config.mu_dtype = "float32" + config.gradient_clipping_threshold = 1.0 + config.warmup_steps_fraction = 0.1 + config.cosine_learning_rate_final_fraction = 0.1 + + # 1. Test Valid Creation + opt = train_distill.get_distillation_optimizer(config, max_train_steps=100) + + # Initialize to check state structure + params = {"a": jnp.array([0.0])} + state = opt.init(params) + + # Verify InjectHyperparamsState is the top-level state (required for Tunix logging) + # Note: When injecting a schedule (callable), optax returns InjectStatefulHyperparamsState + self.assertTrue( + isinstance(state, (optax.InjectHyperparamsState, optax.InjectStatefulHyperparamsState)), + f"State is {type(state)}, expected InjectHyperparamsState or InjectStatefulHyperparamsState", + ) + self.assertIn("learning_rate", state.hyperparams) + + # 2. 
Test Muon Rejection + config.opt_type = "muon" + with self.assertRaisesRegex(ValueError, "Muon optimizer is not currently supported"): + train_distill.get_distillation_optimizer(config, max_train_steps=100) + + def test_monitored_strategy(self): + """Verifies the strategy calculates metrics and returns the correct tuple.""" + strategy = train_distill.MonitoredLogitStrategy( + student_forward_fn=lambda m, **k: None, + teacher_forward_fn=lambda m, **k: None, + labels_fn=lambda t: t, + temperature=1.0, + alpha=0.5, + ) + + # Dummy inputs (batch=1, seq=2, vocab=4) + # Note: Shapes must align for broadcasting + student_logits = jnp.array([[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]) * 10 + teacher_logits = jnp.array([[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]]) * 10 + + # Labels must be One-Hot Encoded to match logits shape (1, 2, 4) + labels_indices = jnp.array([[0, 1]]) + labels = jax.nn.one_hot(labels_indices, 4) + + # Run calculation + _, metrics = strategy.compute_loss(student_logits, teacher_logits, labels) + + # Verify structure + self.assertIsInstance(metrics, dict) + + # Check keys required for TensorBoard + expected_keys = ["distill/soft_loss", "distill/hard_loss", "distill/kl_div", "distill/teacher_loss"] + for key in expected_keys: + self.assertIn(key, metrics) + + # Since inputs match perfectly, KL should be near 0 + self.assertLess(metrics["distill/kl_div"], 1e-5) + + def test_post_process_train_step(self): + """Verifies metrics are moved from aux dict to the trainer buffer.""" + # pylint: disable=no-value-for-parameter + trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) + + # Setup MetricsBuffer mock + mock_buffer = mock.Mock() + mock_buffer.additional_metrics = {} + trainer._buffered_train_metrics = mock_buffer + + # Simulate auxiliary output from strategy + aux_metrics = {"distill/kl_div": jnp.array(0.5), "distill/soft_loss": jnp.array(1.2)} + + # Run Hook + trainer._post_process_train_step(aux_metrics) + + # Verify buffer updated + self.assertIn("distill/kl_div", mock_buffer.additional_metrics) + self.assertIn("distill/soft_loss", mock_buffer.additional_metrics) + + # Verify value appended to list + values_list = mock_buffer.additional_metrics["distill/kl_div"][0] + self.assertEqual(values_list[0], 0.5) + + +if __name__ == "__main__": + absltest.main()
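
For reference, the combined objective implemented by `MonitoredLogitStrategy.compute_loss` above
reduces to the following standalone sketch (jax + optax only; `alpha` and `temperature` correspond
to `distill_alpha` and `distill_temperature` in configs/distillation.yml):

import jax
import jax.numpy as jnp
import optax


def soft_hard_distillation_loss(student_logits, teacher_logits, one_hot_labels, alpha=0.5, temperature=1.0):
  """alpha * T^2 * KL(teacher || student, both softened by T) + (1 - alpha) * CE(student, labels)."""
  s = student_logits.astype(jnp.float32)
  t = teacher_logits.astype(jnp.float32)
  log_student_t = jax.nn.log_softmax(s / temperature, axis=-1)
  teacher_probs_t = jax.nn.softmax(t / temperature, axis=-1)
  soft_loss = jnp.mean(optax.kl_divergence(log_student_t, teacher_probs_t)) * temperature**2
  hard_loss = jnp.mean(optax.softmax_cross_entropy(logits=s, labels=one_hot_labels))
  return alpha * soft_loss + (1.0 - alpha) * hard_loss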