diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index fe4109cee81..fb3fa3cb9a8 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -420,6 +420,82 @@ def _segment_ids_pos_to_seqlens_offsets_fast_causal_path( ) +def run_length_fill_flattened(segment_ids_flattened) -> jnp.ndarray: + """ + Returns an array of run-lengths of the flattened segment ids + """ + # Example for run_length_fill_flattened: + # Input segment_ids_flattened: [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]] + # run_ids: [[0 0 1 1 1 2 3 4 5 5 5 5 5 6 6 6], [0 1 1 2 2 2 3 3 4 4 5 5 5 5 6 6]] + # counts: [[2 3 1 1 1 5 3 0 0 0 0 0 0 0 0 0], [1 2 3 2 2 4 2 0 0 0 0 0 0 0 0 0]] + # Returns segment_ids_run_length_1d: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]] + boundary = jnp.concatenate( + [jnp.broadcast_to(True, (1,)), segment_ids_flattened[1:] != segment_ids_flattened[:-1]] + ) + run_ids = jnp.cumsum(boundary) - 1 + # Each element could, in worst case, start a run + max_runs = segment_ids_flattened.shape[-1] + counts = jnp.bincount(run_ids, length=max_runs) + # Fill in the missing values + segment_ids_run_length_1d = counts[run_ids] + segment_ids_run_length_1d = jnp.where(segment_ids_flattened == 0, 0, segment_ids_run_length_1d) + return segment_ids_run_length_1d + + +def run_length_fill(segment_ids) -> jnp.ndarray: + """ + Returns an array of run-lengths of the segment ids, with shape preserved + """ + # Example for run_length_fill: + # Input segment_ids: [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]] + # Returns run length: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]] + # Flatten all dimension except the last one prior to executing vmap run length + orig_shape = segment_ids.shape + segment_ids_flat = segment_ids.reshape(-1, orig_shape[-1]) + run_length_segment_id_shape = jax.vmap(run_length_fill_flattened, in_axes=0)(segment_ids_flat) + return run_length_segment_id_shape.reshape(orig_shape) + + +def _get_seqlens_thd(segment_ids, max_segments_per_seq): + # Create mask for non-zero seg ids and get the non-zero indices associated with the same + non_zero_mask = segment_ids != 0 + max_size = segment_ids.shape[-1] + non_zero_indices = jax.vmap( + lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0] + )(non_zero_mask) + + # Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos + # Clip -1 to 0 for safe indexing + clipped_indices = jnp.clip(non_zero_indices, 0, None) + valid_segment_ids = jnp.where( + non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0 + ) + seqlens_all = jax.vmap( + lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:] + )(valid_segment_ids) + seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all) + return seqlens_all_pad_neg + + +def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq): + segment_changes = jnp.concatenate( + [ + jnp.full( + (segment_pos.shape[0], 1), True, dtype=bool + ), # First valid element starts a segment + (segment_pos[..., 1:] != segment_pos[..., :-1] + 1), # Segment pos changed + ], + axis=-1, + ) + # Remove any padded region segment changes + segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False) + # Get the indices for segment changes (these are the offsets) + seq_offsets = jax.vmap( + lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0] + )(segment_changes_masked) + return seq_offsets + + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, segment_ids_kv, @@ -443,37 +519,33 @@ def _segment_ids_pos_to_seqlens_offsets( # # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to # examine only O(Q+KV) elements. - if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1): - return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( - segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq - ) - # (1 = attend, 0 = masked) - segment_mask = make_attention_mask( - segment_ids_q, - segment_ids_kv, - jnp.equal, + # For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation + # using the segment ids and pos along with mask type (causal or brcm) is sufficient. + # It does not need to involve SW for this mask's creation + + # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well + # if (attn_mask_type.is_causal() and window_size is None) or ( + # window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + # ): + # return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( + # segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq + # ) + q_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq ) - segment_mask_with_id = make_attention_mask( - segment_ids_q, - segment_ids_kv, - lambda x, y: jnp.equal(x, y) * x, + kv_seqlen = _get_seqlens_thd( + segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq ) - attn_mask = segment_mask - if attn_mask_type.is_causal(): - causal_mask = make_attention_mask( - segment_pos_q, - segment_pos_kv, - jnp.greater_equal, - ) - attn_mask = jnp.logical_and(segment_mask, causal_mask) - - swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) - attn_mask = jnp.logical_and(attn_mask, swa_mask) - - attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) - q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( - attn_mask_with_id, max_segments_per_seq + q_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_q, + segment_pos=segment_pos_q, + max_segments_per_seq=max_segments_per_seq, + ) + kv_offset = _get_seqoffsets_thd( + segment_ids=segment_ids_kv, + segment_pos=segment_pos_kv, + max_segments_per_seq=max_segments_per_seq, ) return q_seqlen, kv_seqlen, q_offset, kv_offset