diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 137fa480ddb..1593f50f042 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -12,7 +12,7 @@ from transformer_engine.jax.sharding import MeshResource -from utils import assert_allclose, is_devices_enough +from utils import assert_allclose, is_devices_enough, is_devices_equal def generate_configs(): @@ -49,7 +49,10 @@ def generate_context_parallel_configs_for_attn(): TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp - if is_devices_enough(ndev): + # Run only those dp,cp,tp combinations which require exactly ndev GPUs. + # For e.g., if ndev=8 and num GPUs is 8, thent hose combinations will be picked. + # However, if ndev=4, but num GPUs is 8, then those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3. + if is_devices_equal(ndev): # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) if cp != 1: configsL1.append( diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 49372fda1d4..ebf0543697a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1056,41 +1056,70 @@ def check_dqkv(primitive, reference, pad, idx): ], ) @pytest.mark.parametrize( - "qkv_layout", - [ - pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"), - pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), - pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"), - pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"), - pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"), - ], -) -@pytest.mark.parametrize( - "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype", + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", [ + # large + QKV_PACKED and RAGGED_QKV_PACKED? pytest.param( - 2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF" + 2, + 2048, + 2048, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BS3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED", ), pytest.param( 2, - 512, - 1024, + 2048, + 2048, 12, 12, 64, 64, jnp.bfloat16, - id="2-512-1024-12-12-64-64-BF16-CROSS", + QKVLayout.T3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED", ), + # mid + cross + KV_PACKED and RAGGED_KV_PACKED pytest.param( - 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED", ), pytest.param( - 4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF" + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", ), + # large + cross + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? pytest.param( - 4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF" + 2, + 2048, + 1024, + 12, + 12, + 64, + 32, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE", ), pytest.param( 2, @@ -1101,10 +1130,109 @@ def check_dqkv(primitive, reference, pad, idx): 64, 32, jnp.bfloat16, - id="2-2048-1024-12-12-64-32-BF16-CROSS", + QKVLayout.THD_THD_THD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE", + ), + # large + gqa + KV_PACKED and RAGGED_KV_PACKED ? + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED", ), + # small + fp16 + diff hidden v dim + QKV_PACKED and RAGGED_QKV_PACKED ? pytest.param( - 2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA" + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.BS3HD, + id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.T3HD, + id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED", + ), + # small + fp16 + KV_PACKED and RAGGED_KV_PACKED ? + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.THD_T2HD, + id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED", + ), + # large + fp16 + gqa + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? + # TODO: Consider making this a CROSS case ? + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.THD_THD_THD, + id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE", ), ], ) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 7194e387c73..39307075024 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -47,6 +47,13 @@ def is_devices_enough(required): return len(jax.devices()) >= required +def is_devices_equal(required): + """ + Check if the available GPUs is exactly equal + """ + return len(jax.devices()) == required + + def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: # Generate broadcast dims for drop_path. drop_path_shape = list(range(0, len(shape)))