Skip to content

Conversation

@KshitijLakhani
Copy link
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 24, 2025

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  1. Current tests run all possible combinations for L1 and L2 where dp*cp*tp <= num gpus. Below is the L1 dist timing for TE 2.11 (B200x8)
2157 collected
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  27x |    2.84s | avg:   0.11s
test_context_parallel_allgather_attn                         | 320x |  917.64s | avg:   2.87s
test_context_parallel_allgather_attn_shardy                  |  80x |  503.19s | avg:   6.29s
test_context_parallel_allgather_striped_attn                 | 320x |  338.28s | avg:   1.06s
test_context_parallel_ring_attn                              | 1280x | 1662.10s | avg:   1.30s
test_context_parallel_ring_attn_shardy                       |  40x |   46.24s | avg:   1.16s
test_cross_attn                                              |  18x |   31.73s | avg:   1.76s
test_self_attn                                               |  54x |  125.43s | avg:   2.32s
test_self_attn_shardy                                        |  18x |   16.87s | avg:   0.94s
================================================================================
TOTAL RUNTIME                                                |      | 3644.31s |
================================================================================

This PR runs only those L1 and L2 combinations where dp*cp*tp==num gpus. Below is the L1 dist timing for this PR (B200x8)

1137 collected
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  27x |    2.37s | avg:   0.09s
test_context_parallel_allgather_attn                         | 160x |  344.81s | avg:   2.16s
test_context_parallel_allgather_attn_shardy                  |  40x |  167.59s | avg:   4.19s
test_context_parallel_allgather_striped_attn                 | 160x |  190.06s | avg:   1.19s
test_context_parallel_ring_attn                              | 640x |  729.42s | avg:   1.14s
test_context_parallel_ring_attn_shardy                       |  20x |   16.94s | avg:   0.85s
test_cross_attn                                              |  18x |   23.38s | avg:   1.30s
test_self_attn                                               |  54x |  118.77s | avg:   2.20s
test_self_attn_shardy                                        |  18x |   12.75s | avg:   0.71s
================================================================================
TOTAL RUNTIME                                                |      | 1606.09s |
================================================================================

There is a reduction of 1020 (2157-1137) tests collected owing to the change in this PR.
For the test_context_parallel tests, the number of test are halved in number in this PR as only the test cases for dp*cp*tp==8 are collected but not those for dp*cp*tp==4 and dp*cp*tp==2. This is not that big a problem in CI as we run H100x4 and GB200x4 so test cases for dp*cp*tp==4 will be covered in there.

Cons of this change:

  1. dp*cp*tp==2 test will not be covered in it's current form. TODO: If coverage for this is needed, CI could set cuda_visible_devices=0,1 for any of these configs an run these tests as well
  2. With current CI configs, the current tests would run dp*cp*tp<=8 for B200, however, with this PR, we will only run dp*cp*tp==8 cases. The current tests would run dp*cp*tp<=4 for H100, however, with this PR, we will only run dp*cp*tp==4 cases. The current tests would run dp*cp*tp<=4 for GB200, however, with this PR, we will only run dp*cp*tp==4 cases. Overall test cases would still be the same but we just would not have all combinations available for a given CI config (GPU arch) running on it

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

KshitijLakhani and others added 2 commits December 24, 2025 19:19
…only those cp,dp,tp combinations are picked where cp*dp*tp is equal to num gpus

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani changed the title Refactor and trim TE JAX Attn testing [JAX] Refactor and trim TE JAX Attn testing Dec 24, 2025
@KshitijLakhani KshitijLakhani self-assigned this Dec 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant