Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 100 additions & 28 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,82 @@ def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
)


def run_length_fill_flattened(segment_ids_flattened) -> jnp.ndarray:
    """
    Replace every entry of a 1D segment-id array with the length of the run it belongs to.

    A run is a maximal stretch of consecutive equal values. Entries whose segment id
    is 0 are reported with a run length of 0. The output has the same length as the input.

    Example (a single row):
        segment_ids_flattened: [1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0]
        run index per element: [0 0 1 1 1 2 3 4 5 5 5 5 5 6 6 6]
        elements per run:      [2 3 1 1 1 5 3 0 0 0 0 0 0 0 0 0]
        returned run lengths:  [2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0]
    """
    # The first element always opens a run; afterwards a run opens wherever the value changes
    value_changed = segment_ids_flattened[1:] != segment_ids_flattened[:-1]
    starts_run = jnp.concatenate([jnp.ones((1,), dtype=bool), value_changed])
    # Number the runs 0, 1, 2, ... in order of appearance
    run_index = jnp.cumsum(starts_run) - 1
    # In the worst case every element is its own run, so this many bins always suffice
    num_bins = segment_ids_flattened.shape[-1]
    run_sizes = jnp.bincount(run_index, length=num_bins)
    # Hand each run's size back to all of its members, then zero out segment id 0
    per_element_length = run_sizes[run_index]
    return jnp.where(segment_ids_flattened != 0, per_element_length, 0)


def run_length_fill(segment_ids) -> jnp.ndarray:
    """
    Compute run lengths along the last axis of ``segment_ids``, keeping the input shape.

    Every leading axis is treated as a batch axis: each slice along the last axis is
    handed to ``run_length_fill_flattened`` independently.

    Example:
        segment_ids: [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]]
        returns:     [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]]
    """
    input_shape = segment_ids.shape
    # Collapse all leading axes into one row axis so a single vmap covers every slice
    rows = segment_ids.reshape(-1, input_shape[-1])
    fill_each_row = jax.vmap(run_length_fill_flattened)
    # Restore the caller's original layout
    return fill_each_row(rows).reshape(input_shape)


def _get_seqlens_thd(segment_ids, max_segments_per_seq):
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
non_zero_mask = segment_ids != 0
max_size = segment_ids.shape[-1]
non_zero_indices = jax.vmap(
lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0]
)(non_zero_mask)

# Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos
# Clip -1 to 0 for safe indexing
clipped_indices = jnp.clip(non_zero_indices, 0, None)
valid_segment_ids = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0
)
seqlens_all = jax.vmap(
lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:]
)(valid_segment_ids)
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)
return seqlens_all_pad_neg


def _get_seqoffsets_thd(segment_ids, segment_pos, max_segments_per_seq):
segment_changes = jnp.concatenate(
[
jnp.full(
(segment_pos.shape[0], 1), True, dtype=bool
), # First valid element starts a segment
(segment_pos[..., 1:] != segment_pos[..., :-1] + 1), # Segment pos changed
],
axis=-1,
)
# Remove any padded region segment changes
segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False)
# Get the indices for segment changes (these are the offsets)
seq_offsets = jax.vmap(
lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq + 1, fill_value=-1)[0]
)(segment_changes_masked)
return seq_offsets


def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
Expand All @@ -443,37 +519,33 @@ def _segment_ids_pos_to_seqlens_offsets(
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)

# (1 = attend, 0 = masked)
segment_mask = make_attention_mask(
segment_ids_q,
segment_ids_kv,
jnp.equal,
# For seqlens and seqoffsets calculations, the intermediate (temp) attn_mask creation
# using the segment ids and pos along with the mask type (causal or BRCM) is sufficient.
# It does not need to involve the sliding window (SW) for this mask's creation.

# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
# if (attn_mask_type.is_causal() and window_size is None) or (
# window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
# ):
# return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
# segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
# )
q_seqlen = _get_seqlens_thd(
segment_ids=segment_ids_q, max_segments_per_seq=max_segments_per_seq
)
segment_mask_with_id = make_attention_mask(
segment_ids_q,
segment_ids_kv,
lambda x, y: jnp.equal(x, y) * x,
kv_seqlen = _get_seqlens_thd(
segment_ids=segment_ids_kv, max_segments_per_seq=max_segments_per_seq
)
attn_mask = segment_mask
if attn_mask_type.is_causal():
causal_mask = make_attention_mask(
segment_pos_q,
segment_pos_kv,
jnp.greater_equal,
)
attn_mask = jnp.logical_and(segment_mask, causal_mask)

swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
attn_mask = jnp.logical_and(attn_mask, swa_mask)

attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
attn_mask_with_id, max_segments_per_seq
q_offset = _get_seqoffsets_thd(
segment_ids=segment_ids_q,
segment_pos=segment_pos_q,
max_segments_per_seq=max_segments_per_seq,
)
kv_offset = _get_seqoffsets_thd(
segment_ids=segment_ids_kv,
segment_pos=segment_pos_kv,
max_segments_per_seq=max_segments_per_seq,
)
return q_seqlen, kv_seqlen, q_offset, kv_offset

Expand Down