From 6857b287798842c69ab082295e0dd15fb48a7182 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 24 Dec 2025 19:19:31 +0000 Subject: [PATCH 1/3] Pick a leaner set of combinations for TE JAX CP attn tests such that only those cp,dp,tp combinations are picked where cp*dp*tp is equal to num gpus Signed-off-by: Kshitij Lakhani --- tests/jax/distributed_test_base.py | 7 +++++-- tests/jax/utils.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 137fa480ddb..f86f81ec48f 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -12,7 +12,7 @@ from transformer_engine.jax.sharding import MeshResource -from utils import assert_allclose, is_devices_enough +from utils import assert_allclose, is_devices_enough, is_devices_equal def generate_configs(): @@ -49,7 +49,10 @@ def generate_context_parallel_configs_for_attn(): TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp - if is_devices_enough(ndev): + # Run only those dp,cp,tp combinations which require exactly ndev GPUs. + # For e.g., if ndev=8 and num GPUs is 8, thent hose combinations will be picked. + # However, if ndev=4, but num GPUs is 8, then those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3. + if is_devices_equal(ndev): # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) if cp != 1: configsL1.append( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 7194e387c73..c3311395a0d 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -46,6 +46,12 @@ def is_devices_enough(required): """ return len(jax.devices()) >= required +def is_devices_equal(required): + """ + Check if the available GPUs is exactly equal + """ + return len(jax.devices()) == required + def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: # Generate broadcast dims for drop_path. From d6a29518a2c5e21a537a42dab22f6f229f3884bd Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 8 Jan 2026 01:29:51 +0000 Subject: [PATCH 2/3] Consolidate the test cases run for different B,S,H,D and QKV layout Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 73 ++++++++++++------------------------ 1 file changed, 24 insertions(+), 49 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 49372fda1d4..85b38e5f538 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1056,56 +1056,31 @@ def check_dqkv(primitive, reference, pad, idx): ], ) @pytest.mark.parametrize( - "qkv_layout", + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", [ - pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"), - pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), - pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"), - pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"), - pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"), - ], -) -@pytest.mark.parametrize( - "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype", - [ - pytest.param( - 2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF" - ), - pytest.param( - 2, - 512, - 1024, - 12, - 12, - 64, - 64, - jnp.bfloat16, - id="2-512-1024-12-12-64-64-BF16-CROSS", - ), - pytest.param( - 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" - ), - pytest.param( - 4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF" - ), - pytest.param( - 4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF" - ), - pytest.param( - 2, - 2048, - 1024, - 12, - 12, - 64, - 32, - jnp.bfloat16, - id="2-2048-1024-12-12-64-32-BF16-CROSS", - ), - pytest.param( - 2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA" - ), + # large + QKV_PACKED and RAGGED_QKV_PACKED? + pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BS3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED"), + pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.T3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED"), + # mid + cross + KV_PACKED and RAGGED_KV_PACKED + pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED"), + pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED"), + # large + cross + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? + pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE"), + pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.THD_THD_THD, id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE"), + # large + gqa + KV_PACKED and RAGGED_KV_PACKED ? + pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED"), + pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED"), + + # small + fp16 + diff hidden v dim + QKV_PACKED and RAGGED_QKV_PACKED ? + pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.BS3HD, id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED"), + pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.T3HD, id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED"), + # small + fp16 + KV_PACKED and RAGGED_KV_PACKED ? + pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.BSHD_BS2HD, id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED"), + pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.THD_T2HD, id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED"), + # large + fp16 + gqa + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? + # TODO: Consider making this a CROSS case ? + pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE"), + pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.THD_THD_THD, id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE"), ], ) @pytest.mark.parametrize( From 2dd5068a90b86d3651098cc80cf2db98ba7776c1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:32:09 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/distributed_test_base.py | 2 +- tests/jax/test_fused_attn.py | 183 ++++++++++++++++++++++++++--- tests/jax/utils.py | 1 + 3 files changed, 170 insertions(+), 16 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index f86f81ec48f..1593f50f042 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -49,7 +49,7 @@ def generate_context_parallel_configs_for_attn(): TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp - # Run only those dp,cp,tp combinations which require exactly ndev GPUs. + # Run only those dp,cp,tp combinations which require exactly ndev GPUs. # For e.g., if ndev=8 and num GPUs is 8, thent hose combinations will be picked. # However, if ndev=4, but num GPUs is 8, then those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3. if is_devices_equal(ndev): diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 85b38e5f538..ebf0543697a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1059,28 +1059,181 @@ def check_dqkv(primitive, reference, pad, idx): "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", [ # large + QKV_PACKED and RAGGED_QKV_PACKED? - pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BS3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED"), - pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.T3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED"), + pytest.param( + 2, + 2048, + 2048, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BS3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.T3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED", + ), # mid + cross + KV_PACKED and RAGGED_KV_PACKED - pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED"), - pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED"), + pytest.param( + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED", + ), + pytest.param( + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", + ), # large + cross + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? - pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE"), - pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.THD_THD_THD, id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE"), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + 32, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE", + ), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + 32, + jnp.bfloat16, + QKVLayout.THD_THD_THD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE", + ), # large + gqa + KV_PACKED and RAGGED_KV_PACKED ? - pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED"), - pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED"), - + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED", + ), # small + fp16 + diff hidden v dim + QKV_PACKED and RAGGED_QKV_PACKED ? - pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.BS3HD, id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED"), - pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.T3HD, id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED"), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.BS3HD, + id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.T3HD, + id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED", + ), # small + fp16 + KV_PACKED and RAGGED_KV_PACKED ? - pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.BSHD_BS2HD, id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED"), - pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.THD_T2HD, id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED"), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.THD_T2HD, + id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED", + ), # large + fp16 + gqa + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? # TODO: Consider making this a CROSS case ? - pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE"), - pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.THD_THD_THD, id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE"), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.THD_THD_THD, + id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE", + ), ], ) @pytest.mark.parametrize( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index c3311395a0d..39307075024 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -46,6 +46,7 @@ def is_devices_enough(required): """ return len(jax.devices()) >= required + def is_devices_equal(required): """ Check if the available GPUs is exactly equal