Commit 8e97ea1

sliding window (#3570)

* add QPad3d unit test
* Hard code sliding_window=2 and pass ut
* change api; some bugs need to be fixed
* sliding window pass ut
* Support sliding_window=-1
* change api; refine code
* change api; fix clang
* enable sliding_window on flash_attn_varlen
* fix accuracy issue for _exp_reduce_sum_fusion_kernel
* add TORCH_CHECK to improve stability; improve code style
* enable window_size_right
* fix flake
* add docs
1 parent 5f748ed commit 8e97ea1

File tree

9 files changed (+239, -25 lines)
csrc/cpu/aten/PagedAttention.cpp

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,7 @@ void single_query_cached_kv_attention_forward_cpu(
     int64_t block_size,
     int64_t max_context_len,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -41,6 +42,7 @@ void single_query_cached_kv_attention_forward_cpu(
       block_size,
       max_context_len,
       alibi_slopes,
+      window_size,
       k_scale,
       v_scale,
       softcap);
@@ -71,6 +73,8 @@ void flash_attn_varlen_cpu(
     bool is_causal,
     at::Tensor& block_table,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size_left,
+    int64_t window_size_right,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -88,6 +92,8 @@ void flash_attn_varlen_cpu(
       is_causal,
       block_table,
       alibi_slopes,
+      window_size_left,
+      window_size_right,
       k_scale,
       v_scale,
       softcap);

csrc/cpu/aten/PagedAttention.h

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,7 @@ void single_query_cached_kv_attention(
     int64_t block_size,
     int64_t max_context_len,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size,
     const double k_scale,
     const double v_scale,
     const double softcap);
@@ -47,6 +48,8 @@ void flash_attn_varlen(
     bool is_causal,
     at::Tensor& block_table,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size_left,
+    int64_t window_size_right,
     const double k_scale,
     const double v_scale,
     const double softcap);
@@ -63,6 +66,7 @@ using single_query_cached_kv_attention_fn = void (*)(
     int64_t block_size,
     int64_t max_context_len,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size,
     const double k_scale,
     const double v_scale,
     const double softcap);
@@ -89,6 +93,8 @@ using flash_attn_var_len_fn = void (*)(
     bool is_causal,
     at::Tensor& block_table,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size_left,
+    int64_t window_size_right,
     const double k_scale,
     const double v_scale,
     const double softcap);

csrc/cpu/aten/kernels/PagedAttentionKrnl.cpp

Lines changed: 104 additions & 15 deletions
@@ -413,6 +413,7 @@ inline void _exp_reduce_sum_fusion_kernel(
     const int& size,
     T2* out,
     T1& val) {
+  TORCH_CHECK(val != -std::numeric_limits<float>::infinity());
   auto vec_size = at::vec::Vectorized<T1>::size();
   auto vec_max = at::vec::Vectorized<T1>(val);
   T1 tmp_sum = 0;
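The new TORCH_CHECK rejects a row maximum of -inf, which can only happen when every logit in the row was masked. Without this guard (and the caller-side fixes further down), the fused exp/sum would evaluate exp(-inf - (-inf)), which is NaN. A tiny standalone illustration of that failure mode; it is not part of the kernel:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  // A fully masked row: every logit is -inf, so the row max is also -inf.
  float logit = neg_inf;
  float row_max = neg_inf;
  // exp(logit - row_max) = exp(-inf - (-inf)) = exp(NaN) = NaN,
  // which would poison the running softmax sum.
  float weight = std::exp(logit - row_max);
  std::printf("weight = %f\n", weight); // prints nan on IEEE-754 platforms
  return 0;
}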
@@ -1022,10 +1023,13 @@ void single_query_cached_kv_attention_kernel(
     int64_t block_size,
     int64_t max_context_len,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size,
     const double k_scale,
     const double v_scale,
     const double softcap) {
   bool use_softcap = softcap == -1 ? false : true;
+  // TODO: Support both use_softcap and window_size
+  TORCH_CHECK(!(window_size > 0 && use_softcap == true));
   auto scale_ = use_softcap ? 1.0 : scale;
   auto out_ptr = out.data_ptr<scalar_t>();
   auto query_ptr = query.data_ptr<scalar_t>();
@@ -1068,6 +1072,21 @@ void single_query_cached_kv_attention_kernel(
       {num_seqs, num_heads, max_num_partitions, head_size},
       query.options().dtype(at::ScalarType::Float));
 
+  bool is_local = window_size > 0 && window_size < max_context_len;
+  if (is_local) {
+    max_logits = at::zeros(
+        {num_seqs, num_heads, max_num_partitions + 1},
+        query.options().dtype(at::ScalarType::Float));
+
+    exp_sum = at::zeros(
+        {num_seqs, num_heads, max_num_partitions + 1},
+        query.options().dtype(at::ScalarType::Float));
+
+    tmp_out = at::zeros(
+        {num_seqs, num_heads, max_num_partitions, head_size},
+        query.options().dtype(at::ScalarType::Float));
+  }
+
   auto tmp_out_ptr = tmp_out.data_ptr<float>();
   auto max_logits_ptr = max_logits.data_ptr<float>();
   auto exp_sum_ptr = exp_sum.data_ptr<float>();
@@ -1109,6 +1128,9 @@ void single_query_cached_kv_attention_kernel(
         continue;
       auto partition_end =
           std::min(partition_start + PARTITION_SIZE, context_len);
+      long sliding_window_start = is_local ? context_len - window_size : -1;
+      if (is_local && partition_end < sliding_window_start)
+        continue;
       auto token_num = partition_end - partition_start;
       auto block_num = (token_num + block_size - 1) / block_size;
       auto logical_block_start = partition_start / block_size;
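With a window of the most recent window_size tokens, any partition that ends before context_len - window_size contains only masked keys, so the decode kernel now skips it outright. An illustrative standalone example of the skip condition with made-up sizes (context_len = 1000, window_size = 128, PARTITION_SIZE = 512):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t context_len = 1000;
  const int64_t window_size = 128;
  const int64_t PARTITION_SIZE = 512;
  // Only keys in [context_len - window_size, context_len) are visible.
  const int64_t sliding_window_start = context_len - window_size; // 872
  for (int64_t partition_start = 0; partition_start < context_len;
       partition_start += PARTITION_SIZE) {
    const int64_t partition_end =
        std::min(partition_start + PARTITION_SIZE, context_len);
    const bool skipped = partition_end < sliding_window_start;
    std::printf(
        "partition [%lld, %lld) -> %s\n",
        static_cast<long long>(partition_start),
        static_cast<long long>(partition_end),
        skipped ? "skipped" : "processed");
  }
  return 0;
}

Within a partition that survives, tokens before sliding_window_start are still individually masked to -inf, as the next hunk shows.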
@@ -1141,13 +1163,20 @@ void single_query_cached_kv_attention_kernel(
           auto k_cache_start = key_cache_ptr +
               physical_block_id * kv_block_strideN +
               block_offset * kv_block_strideP + kv_head_id * kv_block_strideH;
-          reduce_head(
-              q_ptr_start,
-              kv_head_group_size,
-              k_cache_start,
-              &(logits[logits_position]),
-              PARTITION_SIZE,
-              head_size);
+          if (is_local && token_id < sliding_window_start) {
+            for (auto i = 0; i < kv_head_group_size; i++) {
+              logits[logits_position + i * PARTITION_SIZE] =
+                  -std::numeric_limits<float>::infinity();
+            }
+          } else {
+            reduce_head(
+                q_ptr_start,
+                kv_head_group_size,
+                k_cache_start,
+                &(logits[logits_position]),
+                PARTITION_SIZE,
+                head_size);
+          }
           logits_position++;
         }
       }
@@ -1182,6 +1211,9 @@ void single_query_cached_kv_attention_kernel(
         }
         max_logits_ptr[max_logits_offset + hi * max_logits_strideH] =
             partition_max;
+        if (partition_max == -std::numeric_limits<float>::infinity()) {
+          partition_max = 0;
+        }
         _exp_reduce_sum_fusion_kernel<float, float>(
             logits + hi * PARTITION_SIZE,
             token_num,
@@ -1423,6 +1455,8 @@ void flash_attn_varlen_kernel(
     bool is_causal, // whether the attention is causal
     at::Tensor& block_table,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size_left,
+    int64_t window_size_right,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -1449,6 +1483,17 @@ void flash_attn_varlen_kernel(
   auto qSliceMax = (max_seqlen_q + qSplitSize - 1) / qSplitSize;
   auto kvSliceMax = (max_seqlens_k + kvSplitSize - 1) / kvSplitSize;
 
+  if (is_causal) {
+    window_size_right = 0;
+  }
+  if (window_size_left >= max_seqlens_k) {
+    window_size_left = -1;
+  }
+  if (window_size_right >= max_seqlens_k) {
+    window_size_right = -1;
+  }
+  bool is_local = (window_size_left != -1) | (window_size_right != -1);
+
   constexpr bool is_reduced_type =
       at::vec::is_reduced_floating_point_v<scalar_t>;
   using accum_t = at::opmath_type<scalar_t>;
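The prefill kernel first normalizes the window arguments: causal attention is expressed as window_size_right = 0, and a window at least as large as the longest key sequence is equivalent to no window at all (-1), so is_local can turn the masking loop off entirely. A small standalone sketch of that normalization; the helper is illustrative, not the kernel itself:

#include <cstdint>
#include <cstdio>

// Mirrors the normalization at the top of flash_attn_varlen_kernel.
void normalize_window(
    bool is_causal,
    int64_t max_seqlens_k,
    int64_t& window_size_left,
    int64_t& window_size_right) {
  if (is_causal) {
    window_size_right = 0; // causal means "no future tokens"
  }
  if (window_size_left >= max_seqlens_k) {
    window_size_left = -1; // window wider than the sequence: unbounded
  }
  if (window_size_right >= max_seqlens_k) {
    window_size_right = -1;
  }
}

int main() {
  int64_t left = 4096;
  int64_t right = -1;
  normalize_window(/*is_causal=*/true, /*max_seqlens_k=*/1024, left, right);
  // left collapses to -1 (window wider than the sequence) and causal forces
  // right to 0, so is_local = (left != -1) | (right != -1) stays true and the
  // causal mask is now applied through the is_local branch.
  std::printf(
      "left=%lld right=%lld\n",
      static_cast<long long>(left),
      static_cast<long long>(right));
  return 0;
}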
@@ -1534,6 +1579,14 @@ void flash_attn_varlen_kernel(
               physical_block_id * kv_block_strideN +
               kv_head_id * kv_block_strideH;
           int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
+          if (window_size_left > 0 and
+              m + context_len - window_size_left > n + kvBlockSize) {
+            continue;
+          }
+          if (window_size_right >= 0 and
+              m + context_len + qBlockSize + window_size_right + 1 <= n) {
+            continue;
+          }
           // Calculate the scale * query * key
           // query block[qBlockSize, head_size], key block: [kvBlockSize,
           // head_size]
@@ -1563,17 +1616,25 @@ void flash_attn_varlen_kernel(
                 scaling_factor);
           }
         }
-        // apply causal mask, fill unmasked position with -inf
-        if (is_causal) {
+
+        // apply mask, fill unmasked position with -inf
+        if (is_local) {
           for (int64_t q = 0; q < qBlockSize; q++) {
             for (int64_t p = 0; p < kvBlockSize; p++) {
-              if (m + q + context_len < n + p) {
+              int64_t idx = context_len + m + q;
+              if (window_size_left > 0 and idx - window_size_left > n + p) {
+                qk_data[q * kvSplitSize + p] =
+                    -std::numeric_limits<accum_t>::infinity();
+              }
+              if (window_size_right >= 0 and
+                  idx + window_size_right + 1 <= n + p) {
                 qk_data[q * kvSplitSize + p] =
                     -std::numeric_limits<accum_t>::infinity();
               }
             }
           }
         }
+
         // Calculate max and sum of exp(val-max)
         for (int64_t q = 0; q < qBlockSize; q++) {
           accum_t tmp_max = -std::numeric_limits<accum_t>::infinity(),
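Inside a surviving key block the mask is applied per element: for the query at absolute position idx, keys below idx - window_size_left or above idx + window_size_right are set to -inf. A worked example with illustrative values (idx = 10, window_size_left = 3, window_size_right = 0):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t idx = 10; // absolute position of the query token
  const int64_t window_size_left = 3;
  const int64_t window_size_right = 0; // causal: no future keys
  for (int64_t kv = 5; kv <= 12; kv++) {
    const bool mask_left = window_size_left > 0 && idx - window_size_left > kv;
    const bool mask_right =
        window_size_right >= 0 && idx + window_size_right + 1 <= kv;
    std::printf(
        "kv=%lld -> %s\n",
        static_cast<long long>(kv),
        (mask_left || mask_right) ? "-inf" : "kept");
  }
  // Only kv = 7..10 are kept: the query itself plus the 3 previous keys.
  return 0;
}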
@@ -1587,15 +1648,21 @@ void flash_attn_varlen_kernel(
               tmp_max);
 
           tmp_max = qk_max_data[q] > tmp_max ? qk_max_data[q] : tmp_max;
-          tmp_sum = tmp_max;
+          tmp_sum = tmp_max != -std::numeric_limits<accum_t>::infinity()
+              ? tmp_max
+              : 0;
           _exp_reduce_sum_fusion_kernel<accum_t, scalar_t>(
               qk_data + q * kvSplitSize,
               kvBlockSize,
               conditional_data_ptr(qk_data, qk_reduced_data) +
                   q * kvSplitSize,
               tmp_sum);
           // exp_tmp <- exp(max[row] - max)
-          exp_tmp = std::exp(qk_max_data[q] - tmp_max);
+          if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
+            exp_tmp = std::exp(qk_max_data[q]);
+          } else {
+            exp_tmp = std::exp(qk_max_data[q] - tmp_max);
+          }
           // sum[row] <- sum + exp_tmp * sum[row]
           qk_sum_data[q] = tmp_sum + exp_tmp * qk_sum_data[q];
           // max[row] <- max
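With local masking a query row may see no keys at all in the blocks processed so far, so its running max stays at -inf. The two changes above keep the online-softmax update finite in that case: tmp_sum is seeded with 0 instead of -inf, and the rescale factor falls back to exp(qk_max_data[q]) = exp(-inf) = 0 rather than exp(-inf - (-inf)) = NaN. A minimal sketch of that rescaling rule under the same assumption; it is standalone and does not reproduce the kernel's actual state layout:

#include <algorithm>
#include <cmath>
#include <limits>

// running_max is -inf until the first unmasked key block is seen for the row.
inline float rescale_factor(float running_max, float block_max) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const float new_max = std::max(running_max, block_max);
  if (new_max == neg_inf) {
    // Row still fully masked: scale the (empty) running sum by exp(-inf) = 0
    // instead of producing exp(-inf - (-inf)) = NaN.
    return std::exp(running_max);
  }
  return std::exp(running_max - new_max);
}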
@@ -1656,6 +1723,7 @@ void single_query_cached_kv_attention_kernel_impl(
     int64_t block_size,
     int64_t max_context_len,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -1739,6 +1807,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1754,6 +1823,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1769,6 +1839,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1784,6 +1855,7 @@ void single_query_cached_kv_attention_kernel_impl(
         block_size,
         max_context_len,
         alibi_slopes,
+        window_size,
         k_scale,
         v_scale,
         softcap);
@@ -1849,6 +1921,8 @@ void flash_attn_varlen_cpu_kernel_impl(
     bool is_causal,
     at::Tensor& block_table,
     const c10::optional<at::Tensor>& alibi_slopes,
+    int64_t window_size_left,
+    int64_t window_size_right,
     const double k_scale,
     const double v_scale,
     const double softcap) {
@@ -1858,9 +1932,6 @@ void flash_attn_varlen_cpu_kernel_impl(
   TORCH_CHECK(
       !alibi_slopes.has_value(),
       "alibi_slopes is not supported for flash_attn_varlen yet");
-  TORCH_CHECK(
-      is_causal,
-      "flash_attn_varlen_cpu_kernel_impl only supports causal attention, pls use the is_causal=True");
   TORCH_CHECK(
       query.scalar_type() == out.scalar_type(),
       "query and out should have the same data type");
@@ -1882,6 +1953,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -1899,6 +1972,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -1916,6 +1991,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -1936,6 +2013,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -1953,6 +2032,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -1970,6 +2051,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -1990,6 +2073,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -2007,6 +2092,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
@@ -2024,6 +2111,8 @@ void flash_attn_varlen_cpu_kernel_impl(
         is_causal,
         block_table,
         alibi_slopes,
+        window_size_left,
+        window_size_right,
         k_scale,
         v_scale,
         softcap);
