diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index c56c036f..2129a63f 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -57,7 +57,7 @@ namespace gcpp { namespace HWY_NAMESPACE { static constexpr size_t kNFx8HTileSize = 8; - +static constexpr float kNegInf = -std::numeric_limits::max() / 64.0f; // Transposes q into q_t. // Both are 4D tensors stuffed into a 2-D MatPtrT. // q has shape [batch, qbatch][head, qkv_dim]. @@ -467,7 +467,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( const DF4 df4; using VF4 = hn::Vec; static_assert(kNumQueries >= 1 && kNumQueries <= 4); - VF4 new_max = hn::Set(df4, -std::numeric_limits::max() / 2.0f); + VF4 new_max = hn::Set(df4, kNegInf); VF max_0, max_1, max_2, max_3 = hn::Zero(df); max_0 = hn::Max(x_0_p0, x_0_p1); if constexpr (kNumQueries >= 2) { @@ -490,29 +490,29 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( VF4 one_over_cap = hn::Set(df4, one_over_att_cap); new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap))); } - VF4 old_max_vf = hn::Set(df4, -std::numeric_limits::max() / 2.0f); + VF4 old_max_vf = hn::Set(df4, kNegInf); old_max_vf = hn::LoadU(df4, old_max); new_max = hn::Max(new_max, old_max_vf); + auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf)); // TODO figure out what was wrong with broadcasts and change to that. - HWY_ALIGN float tmp_max[4]; - hn::Store(new_max, df4, tmp_max); + hn::StoreU(new_max, df4, old_max); if constexpr (kNumQueries >= 1) { - const VF new_max_0 = hn::Set(df, tmp_max[0]); + const VF new_max_0 = hn::Set(df, old_max[0]); x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0, new_max_0)); x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0)); } if constexpr (kNumQueries >= 2) { - const VF new_max_0 = hn::Set(df, tmp_max[1]); + const VF new_max_0 = hn::Set(df, old_max[1]); x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0)); x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0)); } if constexpr (kNumQueries >= 3) { - const VF new_max_0 = hn::Set(df, tmp_max[2]); + const VF new_max_0 = hn::Set(df, old_max[2]); x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0)); x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0)); } if constexpr (kNumQueries >= 4) { - const VF new_max_0 = hn::Set(df, tmp_max[3]); + const VF new_max_0 = hn::Set(df, old_max[3]); x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0)); x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0)); } @@ -520,8 +520,6 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( old_d_vf = hn::LoadU(df4, old_d); VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max))); - hn::StoreU(new_max, df4, old_max); - VF4 x_sum = hn::Zero(df4); if constexpr (kNumQueries == 1) { x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1)); @@ -539,12 +537,12 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( const VF4 zero4 = hn::Zero(df4); const VF4 one_over_d = hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf); - float tmp_one_over_d[4]; + HWY_ALIGN float tmp_one_over_d[4]; hn::Store(one_over_d, df4, tmp_one_over_d); - hn::Store(old_d_vf, df4, old_d); + hn::BlendedStore(old_d_vf, changed_max, df4, old_d); scale = hn::Mul(scale, one_over_d); - hn::Store(scale, df4, scales); - if (hn::ExtractLane(old_d_vf, 0) > 0.0f) { + hn::BlendedStore(scale, changed_max, df4, scales); + if (hn::ExtractLane(old_d_vf, 0) > 0.0f && scales[0] != 1.0f) { const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]); x_0_p0 = hn::Mul(x_0_p0, one_over_d_0); x_0_p1 = hn::Mul(x_0_p1, one_over_d_0); @@ -553,7 +551,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( x_0_p1 = zero; } if constexpr (kNumQueries >= 2) { - if (hn::ExtractLane(old_d_vf, 1) > 0.0f) { + if (hn::ExtractLane(old_d_vf, 1) > 0.0f && scales[1] != 1.0f) { const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]); x_1_p0 = hn::Mul(x_1_p0, one_over_d_1); x_1_p1 = hn::Mul(x_1_p1, one_over_d_1); @@ -563,7 +561,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( } } if constexpr (kNumQueries >= 3) { - if (hn::ExtractLane(old_d_vf, 2) > 0.0f) { + if (hn::ExtractLane(old_d_vf, 2) > 0.0f && scales[2] != 1.0f) { const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]); x_2_p0 = hn::Mul(x_2_p0, one_over_d_2); x_2_p1 = hn::Mul(x_2_p1, one_over_d_2); @@ -573,7 +571,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap( } } if constexpr (kNumQueries >= 4) { - if (hn::ExtractLane(old_d_vf, 3) > 0.0f) { + if (hn::ExtractLane(old_d_vf, 3) > 0.0f && scales[3] != 1.0f) { const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]); x_3_p0 = hn::Mul(x_3_p0, one_over_d_3); x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);