Skip to content

Conversation

@kocchop
Copy link
Collaborator

@kocchop kocchop commented Dec 31, 2025

Description

Enables sequence packing for context parallelism with ring strategy using TransformerEngine's DotProductAttention. Includes comprehensive GPU tests for ring attention with packing for sm90+.

  • Currently supports packing only for ring attention
  • Replaced local sequence reordering with TE reorder_causal_load_balancing api
  • Currently the load balancing strategy is automatically picked based on the packing config

Tests

Added a GPU integration test that works for sm90+.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Dec 31, 2025

Codecov Report

❌ Patch coverage is 0% with 23 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/train_utils.py 0.00% 8 Missing ⚠️
src/MaxText/layers/attention_op.py 0.00% 7 Missing ⚠️
src/MaxText/max_utils.py 0.00% 4 Missing ⚠️
src/MaxText/maxtext_utils.py 0.00% 4 Missing ⚠️

📢 Thoughts on this report? Let us know!

Copy link
Collaborator

@richjames0 richjames0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a couple of nits but lgtm

# Handle packing configurations
if self.config.packing and self.config.dataset_type != "synthetic":
if using_context_parallelism and not using_load_balanced_ring_cp:
raise AssertionError("Packing is only supported for load balanced ring attention with context parallelism.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: AssertionError feels weird here to me. Maybe an argumenterror?



def get_reorder_callable(cp_size, shard_mode):
def get_reorder_callable(cp_size, shard_mode, reorder_strategy=0): # 0=DualChunkSwap, 1=Striped
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I read this late at night I imagine you're using an integer here so it's comprehensible by JAX but could this be made into an enum without breaking things (at worse using .value?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants