@@ -413,6 +413,7 @@ inline void _exp_reduce_sum_fusion_kernel(
     const int& size,
     T2* out,
     T1& val) {
+  TORCH_CHECK(val != -std::numeric_limits<float>::infinity());
   auto vec_size = at::vec::Vectorized<T1>::size();
   auto vec_max = at::vec::Vectorized<T1>(val);
   T1 tmp_sum = 0;
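The fused kernel writes `exp(x[i] - val)` to `out` and returns the exp-sum back through `val`, so a `-inf` running max must be rejected up front: every `exp(x - (-inf))` would overflow to `+inf`. A scalar sketch under that in/out assumption (the real kernel is vectorized with `at::vec`; the helper name here is hypothetical):

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Scalar reference for the fused exp-reduce-sum step: `val` carries the
// row max in and the sum of exponentials out.
template <typename T1, typename T2>
void exp_reduce_sum_reference(const T1* in, int size, T2* out, T1& val) {
  assert(val != -std::numeric_limits<T1>::infinity());
  T1 sum = 0;
  for (int i = 0; i < size; i++) {
    T1 e = std::exp(in[i] - val);  // subtract the row max for stability
    out[i] = static_cast<T2>(e);
    sum += e;
  }
  val = sum;  // in: row max, out: exp-sum
}
```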
@@ -1022,10 +1023,13 @@ void single_query_cached_kv_attention_kernel(
     int64_t block_size,
     int64_t max_context_len,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size,
     const double k_scale,
     const double v_scale,
     const double softcap) {
   bool use_softcap = softcap == -1 ? false : true;
+  // TODO: Support both use_softcap and window_size
+  TORCH_CHECK(!(window_size > 0 && use_softcap));
   auto scale_ = use_softcap ? 1.0 : scale;
   auto out_ptr = out.data_ptr<scalar_t>();
   auto query_ptr = query.data_ptr<scalar_t>();
@@ -1068,6 +1072,21 @@ void single_query_cached_kv_attention_kernel(
       {num_seqs, num_heads, max_num_partitions, head_size},
       query.options().dtype(at::ScalarType::Float));
 
+  bool is_local = window_size > 0 && window_size < max_context_len;
+  if (is_local) {
+    max_logits = at::zeros(
+        {num_seqs, num_heads, max_num_partitions + 1},
+        query.options().dtype(at::ScalarType::Float));
+
+    exp_sum = at::zeros(
+        {num_seqs, num_heads, max_num_partitions + 1},
+        query.options().dtype(at::ScalarType::Float));
+
+    tmp_out = at::zeros(
+        {num_seqs, num_heads, max_num_partitions, head_size},
+        query.options().dtype(at::ScalarType::Float));
+  }
+
   auto tmp_out_ptr = tmp_out.data_ptr<float>();
   auto max_logits_ptr = max_logits.data_ptr<float>();
   auto exp_sum_ptr = exp_sum.data_ptr<float>();
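This hunk turns on sliding-window ("local") attention for the decode path: when `0 < window_size < max_context_len`, the query may only attend to the most recent `window_size` keys, and the per-partition statistics buffers are re-allocated with one extra partition slot. A minimal sketch of the visibility rule, matching `sliding_window_start` as computed further down (the helper name is hypothetical):

```cpp
// The single decode query sits at position context_len - 1 and may only
// attend to the last window_size keys.
inline bool token_in_window(long token_id, long context_len, long window_size) {
  long sliding_window_start = context_len - window_size;  // first visible key
  return token_id >= sliding_window_start;
}
```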
@@ -1109,6 +1128,9 @@ void single_query_cached_kv_attention_kernel(
           continue;
         auto partition_end =
             std::min(partition_start + PARTITION_SIZE, context_len);
+        long sliding_window_start = is_local ? context_len - window_size : -1;
+        if (is_local && partition_end < sliding_window_start)
+          continue;
         auto token_num = partition_end - partition_start;
         auto block_num = (token_num + block_size - 1) / block_size;
         auto logical_block_start = partition_start / block_size;
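Partitions that end before the window even begins are skipped outright; only the straddling partition needs per-token masking. A quick walk-through with illustrative numbers (not from the patch):

```cpp
#include <algorithm>

// context_len = 1000, window_size = 128 => sliding_window_start = 872.
// Partition [0, 512):    partition_end = 512  < 872 -> skipped entirely.
// Partition [512, 1000): partition_end = 1000 >= 872 -> processed; tokens
//                        512..871 are masked per-token, 872..999 attended.
constexpr long PARTITION_SIZE = 512;
bool partition_skipped(long partition_start, long context_len,
                       long sliding_window_start) {
  long partition_end =
      std::min(partition_start + PARTITION_SIZE, context_len);
  return partition_end < sliding_window_start;
}
```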
@@ -1141,13 +1163,20 @@ void single_query_cached_kv_attention_kernel(
           auto k_cache_start = key_cache_ptr +
               physical_block_id * kv_block_strideN +
               block_offset * kv_block_strideP + kv_head_id * kv_block_strideH;
-          reduce_head(
-              q_ptr_start,
-              kv_head_group_size,
-              k_cache_start,
-              &(logits[logits_position]),
-              PARTITION_SIZE,
-              head_size);
+          if (is_local && token_id < sliding_window_start) {
+            for (auto i = 0; i < kv_head_group_size; i++) {
+              logits[logits_position + i * PARTITION_SIZE] =
+                  -std::numeric_limits<float>::infinity();
+            }
+          } else {
+            reduce_head(
+                q_ptr_start,
+                kv_head_group_size,
+                k_cache_start,
+                &(logits[logits_position]),
+                PARTITION_SIZE,
+                head_size);
+          }
           logits_position++;
         }
       }
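Note that masked tokens are not skipped: they keep their slot in the logits scratch block and are overwritten with `-inf`, so `logits_position` stays aligned with the per-head layout that the later reduction indexes via `hi * PARTITION_SIZE`. A sketch of that layout assumption (the helper is hypothetical):

```cpp
#include <limits>
#include <vector>

// Inferred scratch layout: one PARTITION_SIZE row of logits per query head
// in the kv head group. A masked token keeps its column but holds -inf.
constexpr int PARTITION_SIZE = 512;
void mask_token(std::vector<float>& logits, int logits_position,
                int kv_head_group_size) {
  for (int i = 0; i < kv_head_group_size; i++) {
    logits[logits_position + i * PARTITION_SIZE] =
        -std::numeric_limits<float>::infinity();
  }
}
```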
@@ -1182,6 +1211,9 @@ void single_query_cached_kv_attention_kernel(
           }
           max_logits_ptr[max_logits_offset + hi * max_logits_strideH] =
               partition_max;
+          if (partition_max == -std::numeric_limits<float>::infinity()) {
+            partition_max = 0;
+          }
           _exp_reduce_sum_fusion_kernel<float, float>(
               logits + hi * PARTITION_SIZE,
               token_num,
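The clamp matters when an entire partition was masked: `partition_max` is then `-inf`, and `exp(logit - partition_max)` would evaluate `-inf - (-inf)`, which is NaN. Clamping the subtrahend to 0 makes every masked slot contribute a clean `exp(-inf) == 0`. A two-line demonstration:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  float neg_inf = -std::numeric_limits<float>::infinity();
  std::printf("%f\n", std::exp(neg_inf - neg_inf));  // nan: -inf - (-inf)
  std::printf("%f\n", std::exp(neg_inf - 0.0f));     // 0.0, as desired
}
```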
@@ -1423,6 +1455,8 @@ void flash_attn_varlen_kernel(
     bool is_causal, // whether the attention is causal
     at::Tensor& block_table,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size_left,
+    int64_t window_size_right,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -1449,6 +1483,17 @@ void flash_attn_varlen_kernel(
   auto qSliceMax = (max_seqlen_q + qSplitSize - 1) / qSplitSize;
   auto kvSliceMax = (max_seqlens_k + kvSplitSize - 1) / kvSplitSize;
 
+  if (is_causal) {
+    window_size_right = 0;
+  }
+  if (window_size_left >= max_seqlens_k) {
+    window_size_left = -1;
+  }
+  if (window_size_right >= max_seqlens_k) {
+    window_size_right = -1;
+  }
+  bool is_local = (window_size_left != -1) || (window_size_right != -1);
+
   constexpr bool is_reduced_type =
       at::vec::is_reduced_floating_point_v<scalar_t>;
   using accum_t = at::opmath_type<scalar_t>;
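Here `-1` means "no bound on that side": causal attention is just `window_size_right == 0`, and a window at least as large as the longest key sequence is equivalent to no window at all, so it is normalized away and `is_local` stays false. An illustrative sketch of the convention (the helper is hypothetical):

```cpp
#include <cstdint>

// A window covering the whole sequence is as good as no window.
int64_t normalize_window(int64_t w, int64_t max_seqlens_k) {
  return (w >= max_seqlens_k) ? -1 : w;
}
// e.g. with max_seqlens_k = 2048:
//   is_causal               -> window_size_right = 0 (no future tokens)
//   window_size_left = 4096 -> normalized to -1 (unbounded past)
//   is_local is true only if a finite bound survives on either side
```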
@@ -1534,6 +1579,14 @@ void flash_attn_varlen_kernel(
             physical_block_id * kv_block_strideN +
             kv_head_id * kv_block_strideH;
         int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
+        if (window_size_left > 0 &&
+            m + context_len - window_size_left > n + kvBlockSize) {
+          continue;
+        }
+        if (window_size_right >= 0 &&
+            m + context_len + qBlockSize + window_size_right + 1 <= n) {
+          continue;
+        }
         // Calculate the scale * query * key
         // query block[qBlockSize, head_size], key block: [kvBlockSize,
         // head_size]
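These two `continue`s are block-level, slightly conservative versions of the per-element mask applied below: a kv block may be skipped only if it is invisible to every query row it would meet. A sketch of the reasoning, reusing the kernel's variable names (the helper itself is hypothetical):

```cpp
#include <cstdint>

// Query rows sit at absolute positions context_len + m + q, q in
// [0, qBlockSize); keys sit at n + p, p in [0, kvBlockSize).
bool can_skip_block(int64_t m, int64_t n, int64_t qBlockSize,
                    int64_t kvBlockSize, int64_t context_len,
                    int64_t window_size_left, int64_t window_size_right) {
  // Left prune: even the newest key in the block precedes what the oldest
  // query row (q = 0) is allowed to see.
  if (window_size_left > 0 &&
      m + context_len - window_size_left > n + kvBlockSize)
    return true;
  // Right prune: even the newest query row (q = qBlockSize - 1) may not
  // yet see the oldest key in the block (p = 0).
  if (window_size_right >= 0 &&
      m + context_len + qBlockSize + window_size_right + 1 <= n)
    return true;
  return false;
}
```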
@@ -1563,17 +1616,25 @@ void flash_attn_varlen_kernel(
                 scaling_factor);
           }
         }
-        // apply causal mask, fill unmasked position with -inf
-        if (is_causal) {
+
+        // apply mask, fill masked positions with -inf
+        if (is_local) {
           for (int64_t q = 0; q < qBlockSize; q++) {
             for (int64_t p = 0; p < kvBlockSize; p++) {
-              if (m + q + context_len < n + p) {
+              int64_t idx = context_len + m + q;
+              if (window_size_left > 0 && idx - window_size_left > n + p) {
+                qk_data[q * kvSplitSize + p] =
+                    -std::numeric_limits<accum_t>::infinity();
+              }
+              if (window_size_right >= 0 &&
+                  idx + window_size_right + 1 <= n + p) {
                 qk_data[q * kvSplitSize + p] =
                     -std::numeric_limits<accum_t>::infinity();
               }
             }
           }
         }
+
         // Calculate max and sum of exp(val-max)
         for (int64_t q = 0; q < qBlockSize; q++) {
           accum_t tmp_max = -std::numeric_limits<accum_t>::infinity(),
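The per-element rule: a query at absolute position `idx = context_len + m + q` may see key position `n + p` only inside `[idx - window_size_left, idx + window_size_right]`. With `is_causal` (so `window_size_right == 0`) the second test reduces to the old causal condition `idx < n + p`. A hypothetical predicate capturing both tests:

```cpp
#include <cstdint>

// -1 disables the bound on that side (and 0 disables the left bound).
inline bool visible(int64_t idx, int64_t key, int64_t window_size_left,
                    int64_t window_size_right) {
  if (window_size_left > 0 && idx - window_size_left > key)
    return false;  // key too far in the past
  if (window_size_right >= 0 && idx + window_size_right + 1 <= key)
    return false;  // key too far in the future
  return true;
}
```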
@@ -1587,15 +1648,21 @@ void flash_attn_varlen_kernel(
               tmp_max);
 
           tmp_max = qk_max_data[q] > tmp_max ? qk_max_data[q] : tmp_max;
-          tmp_sum = tmp_max;
+          tmp_sum = tmp_max != -std::numeric_limits<accum_t>::infinity()
+              ? tmp_max
+              : 0;
           _exp_reduce_sum_fusion_kernel<accum_t, scalar_t>(
               qk_data + q * kvSplitSize,
               kvBlockSize,
               conditional_data_ptr(qk_data, qk_reduced_data) +
                   q * kvSplitSize,
               tmp_sum);
           // exp_tmp <- exp(max[row] - max)
-          exp_tmp = std::exp(qk_max_data[q] - tmp_max);
+          if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
+            exp_tmp = std::exp(qk_max_data[q]);
+          } else {
+            exp_tmp = std::exp(qk_max_data[q] - tmp_max);
+          }
           // sum[row] <- sum + exp_tmp * sum[row]
           qk_sum_data[q] = tmp_sum + exp_tmp * qk_sum_data[q];
           // max[row] <- max
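With local masking a row of `qk_data` can be all `-inf`, so the online-softmax update needs the same `-inf` guards as the decode path: the fusion kernel is seeded with 0 instead of a `-inf` max, and the rescale factor `exp(max[row] - max)` falls back to `exp(max[row])` (which is 0 for a still-empty row) rather than computing `-inf - (-inf)`. A scalar sketch of the whole update (names are illustrative, not the kernel's):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// -inf-safe online-softmax update for one row against one kv block.
void online_softmax_update(float& row_max, float& row_sum,
                           const float* blk, int blk_len) {
  constexpr float kNegInf = -std::numeric_limits<float>::infinity();
  float blk_max = kNegInf;
  for (int i = 0; i < blk_len; i++) blk_max = std::max(blk_max, blk[i]);
  float new_max = std::max(row_max, blk_max);
  // Fully masked so far: shift by 0 so every exp(-inf - 0) is simply 0.
  float shift = (new_max == kNegInf) ? 0.f : new_max;
  float blk_sum = 0.f;
  for (int i = 0; i < blk_len; i++) blk_sum += std::exp(blk[i] - shift);
  // Rescale the old sum, guarding the -inf - (-inf) = NaN case.
  float exp_tmp = (new_max == kNegInf) ? std::exp(row_max)
                                       : std::exp(row_max - new_max);
  row_sum = blk_sum + exp_tmp * row_sum;
  row_max = new_max;
}
```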
@@ -1656,6 +1723,7 @@ void single_query_cached_kv_attention_kernel_impl(
     int64_t block_size,
     int64_t max_context_len,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -1739,6 +1807,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1754,6 +1823,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1769,6 +1839,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1784,6 +1855,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1849,6 +1921,8 @@ void flash_attn_varlen_cpu_kernel_impl(
     bool is_causal,
     at::Tensor& block_table,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size_left,
+    int64_t window_size_right,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -1858,9 +1932,6 @@ void flash_attn_varlen_cpu_kernel_impl(
   TORCH_CHECK(
       !alibi_slopes.has_value(),
       "alibi_slopes is not supported for flash_attn_varlen yet");
-  TORCH_CHECK(
-      is_causal,
-      "flash_attn_varlen_cpu_kernel_impl only supports causal attention, pls use the is_causal=True");
   TORCH_CHECK(
       query.scalar_type() == out.scalar_type(),
       "query and out should have the same data type");
@@ -1882,6 +1953,8 @@ void flash_attn_varlen_cpu_kernel_impl(
           is_causal,
           block_table,
           alibi_slopes,
+          window_size_left,
+          window_size_right,
           k_scale,
           v_scale,
           softcap);
@@ -1899,6 +1972,8 @@ void flash_attn_varlen_cpu_kernel_impl(
           is_causal,
           block_table,
           alibi_slopes,
+          window_size_left,
+          window_size_right,
           k_scale,
           v_scale,
           softcap);
@@ -1916,6 +1991,8 @@ void flash_attn_varlen_cpu_kernel_impl(
           is_causal,
           block_table,
           alibi_slopes,
+          window_size_left,
+          window_size_right,
           k_scale,
           v_scale,
           softcap);
@@ -1936,6 +2013,8 @@ void flash_attn_varlen_cpu_kernel_impl(
           is_causal,
           block_table,
           alibi_slopes,
+          window_size_left,
+          window_size_right,
           k_scale,
           v_scale,
           softcap);
@@ -1953,6 +2032,8 @@ void flash_attn_varlen_cpu_kernel_impl(
          is_causal,
          block_table,
          alibi_slopes,
+          window_size_left,
+          window_size_right,
          k_scale,
          v_scale,
          softcap);
@@ -1970,6 +2051,8 @@ void flash_attn_varlen_cpu_kernel_impl(
          is_causal,
          block_table,
          alibi_slopes,
+          window_size_left,
+          window_size_right,
          k_scale,
          v_scale,
          softcap);
@@ -1990,6 +2073,8 @@ void flash_attn_varlen_cpu_kernel_impl(
          is_causal,
          block_table,
          alibi_slopes,
+          window_size_left,
+          window_size_right,
          k_scale,
          v_scale,
          softcap);
@@ -2007,6 +2092,8 @@ void flash_attn_varlen_cpu_kernel_impl(
          is_causal,
          block_table,
          alibi_slopes,
+          window_size_left,
+          window_size_right,
          k_scale,
          v_scale,
          softcap);
@@ -2024,6 +2111,8 @@ void flash_attn_varlen_cpu_kernel_impl(
          is_causal,
          block_table,
          alibi_slopes,
+          window_size_left,
+          window_size_right,
          k_scale,
          v_scale,
          softcap);