diff --git a/BUILD.bazel b/BUILD.bazel index cfd7e572..0582fae7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -555,6 +555,7 @@ cc_library( ":ops", ":tensor_stats", ":threading_context", + "@highway//:abort_header_only", ], ) @@ -678,6 +679,7 @@ cc_library( ":attention", ":basics", ":configs", + ":flash_structs", ":gemma_args", ":kv_cache", ":mat", diff --git a/gemma/activations.h b/gemma/activations.h index ba0ceaf3..c14b24e2 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -76,8 +76,16 @@ struct AttentionActivations { : batch_size * layer_config.heads, allocator)), vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)), - vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)), - vit_C(MatFactory("C2", batch_size, seq_len, allocator)), + vit_K_T(MatFactory( + "K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector), + layer_config.heads * + hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector), + allocator, MatPadding::kPacked)), + vit_V_T(MatFactory( + "V2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector), + layer_config.heads * + hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector), + allocator, MatPadding::kPacked)), pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, config.model_dim, allocator)), // att is only valid for AttentionImpl::kOld. @@ -126,7 +134,6 @@ struct AttentionActivations { q.AllocateAndAttachRowPtrs(row_ptrs); q_bf.AllocateAndAttachRowPtrs(row_ptrs); q_T.AllocateAndAttachRowPtrs(row_ptrs); - vit_C.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs); } @@ -136,8 +143,7 @@ struct AttentionActivations { // q_T rows are always qkv_dim! vit_Q.OverrideRows(batch_size); - // vit_K stays seq_len! - vit_C.OverrideRows(batch_size); + // vit_K_T and vit_V_T stay seq_len! pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); @@ -167,8 +173,8 @@ struct AttentionActivations { MatStorageT q_T; // Transposed to maximize attention speed. MatStorageT vit_Q; - MatStorageT vit_K; - MatStorageT vit_C; + MatStorageT vit_K_T; + MatStorageT vit_V_T; MatStorageT pre_att_rms_out; MatStorageT att; // attention vector @@ -214,8 +220,8 @@ struct AttentionActivationsPtrs { q_bf = activations.q_bf; q_T = activations.q_T; vit_Q = activations.vit_Q; - vit_K = activations.vit_K; - vit_C = activations.vit_C; + vit_K_T = activations.vit_K_T; + vit_V_T = activations.vit_V_T; pre_att_rms_out = activations.pre_att_rms_out; att = activations.att; att_out = activations.att_out; @@ -233,8 +239,7 @@ struct AttentionActivationsPtrs { // q_T rows are always qkv_dim! vit_Q.OverrideRows(batch_size); - // vit_K stays seq_len! - vit_C.OverrideRows(batch_size); + // vit_K_T and vit_V_T stay seq_len! pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); @@ -267,8 +272,8 @@ struct AttentionActivationsPtrs { MatPtrT q_T; MatPtrT vit_Q; - MatPtrT vit_K; - MatPtrT vit_C; + MatPtrT vit_K_T; + MatPtrT vit_V_T; // Output of RMSNorm before attention, size batch_size x model_dim. MatPtrT pre_att_rms_out; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 4f4336e8..5922d8cc 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -2260,3 +2260,21 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { +HWY_EXPORT(DispatchTileFlashAttention148); + +void DispatchDispatchTileFlashAttention148( + Tile148Params& params, const MatPtrT& q, const MatPtrT& k, + const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, MatPtrT& att_out, + size_t qkv_dim, ThreadingContext& ctx, const size_t worker, + AttentionImpl attention_impl) { + HWY_DYNAMIC_DISPATCH(DispatchTileFlashAttention148)( + params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker, + attention_impl); +} + +} // namespace gcpp +#endif diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 7d06af95..1f27dfed 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -42,14 +42,6 @@ namespace gcpp { const MatPtr& query_norm_scale, size_t layer_idx, \ const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ \ - void SingleFlashAttention(size_t start_pos, size_t last_pos, \ - const BF16* HWY_RESTRICT q, \ - const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, \ - const AttentionActivationsPtrs& activations, \ - float* HWY_RESTRICT att_out, \ - ThreadingContext& ctx, size_t worker); \ - \ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ size_t total_tasks, size_t target_parallelism); \ \ @@ -83,6 +75,13 @@ HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION) #undef GEMMA_DECL_FLASH_ATTENTION +void DispatchDispatchTileFlashAttention148( + Tile148Params& params, const MatPtrT& q, const MatPtrT& k, + const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, MatPtrT& att_out, + size_t qkv_dim, ThreadingContext& ctx, const size_t worker, + AttentionImpl attention_impl); + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_ diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc index 84354ca2..3efdd695 100644 --- a/gemma/tiled_attention_test.cc +++ b/gemma/tiled_attention_test.cc @@ -544,8 +544,6 @@ void TestAttentionMultipleTokens() { test_env.SetupWeights(); FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.vit_Q); - FillMatPtrT(test_env.activations->attention.vit_K); FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.softmax_max); @@ -590,8 +588,6 @@ void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() { test_env.SetupWeights(); FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.vit_Q); - FillMatPtrT(test_env.activations->attention.vit_K); FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.softmax_max); @@ -763,8 +759,6 @@ void TestAttentionMultipleTokensBF16() { test_env.SetupWeights(); FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.vit_Q); - FillMatPtrT(test_env.activations->attention.vit_K); FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.softmax_max); @@ -807,8 +801,6 @@ void TestAttentionMultipleTokensInt8() { test_env.SetupWeights(); FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.vit_Q); - FillMatPtrT(test_env.activations->attention.vit_K); FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.softmax_max); diff --git a/gemma/vit.cc b/gemma/vit.cc index be14b125..ecc0062e 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -20,6 +20,7 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS +#include "gemma/flash_structs.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -41,6 +42,8 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" // After highway.h +#include "gemma/attention.h" +#include "gemma/flash_attention.h" #include "gemma/gemma-inl.h" #include "ops/ops-inl.h" @@ -68,107 +71,194 @@ class VitAttention { layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv); } - // TODO(philculliton): transition fully to MatMul. - HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() { - const size_t qkv_dim = layer_config_.qkv_dim; - const size_t heads = layer_config_.heads; - HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); - const size_t seq_len = - static_cast(activations_.attention.div_seq_len.GetDivisor()); - const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); - PROFILER_ZONE("Gen.VitAttention.DotSoftmaxMatrix"); - - MatPtrT& Q = activations_.attention.vit_Q; - MatPtrT& K = activations_.attention.vit_K; - MatPtrT& C = activations_.attention.vit_C; - - // Initialize att_out to zero prior to head loop. - ZeroInit(activations_.attention.att_out); - - for (size_t head = 0; head < heads; ++head) { - pool_.Run(0, num_tokens_, caller1_, - [&](uint64_t task, size_t worker) HWY_ATTR { - const size_t token = task; - float* HWY_RESTRICT q = - activations_.attention.q.Row(token) + head * 3 * qkv_dim; - // TODO: shift to MatMul with A.scale once MatMul is confirmed - // working - MulByConst(query_scale, q, qkv_dim); - hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); + // Applies the query scale to the query and converts to QType. + template + void ScaleQuery(const MatPtrT& qkv, const size_t num_tokens, + const size_t heads, const size_t qkv_dim, + const float query_scale, MatPtrT& q_output) { + ParallelFor(Parallelism::kWithinCluster, heads, env_.ctx, + /*cluster_idx=*/0, Callers::kFlashAttention, + [&](size_t head, size_t worker) { + size_t q_offset = head * qkv_dim; + for (size_t token = 0; token < num_tokens; ++token) { + const float* HWY_RESTRICT src_q = + qkv.Row(token) + q_offset * 3; + QType* HWY_RESTRICT dst_q = q_output.Row(token) + q_offset; + for (size_t i = 0; i < qkv_dim; ++i) { + dst_q[i] = hwy::ConvertScalarTo( + hwy::ConvertScalarTo(src_q[i]) * query_scale); + } + } }); + } - pool_.Run( - 0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t seq_idx = task; - float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) + - head * 3 * qkv_dim + qkv_dim; - hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); - }); - - // this produces C, a (num_tokens_, seq_len) matrix of dot products - CallMatMul(Q, K, nullptr, env_, C); - - pool_.Run(0, num_tokens_, caller3_, - [&](uint64_t task, size_t worker) - HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); }); - - pool_.Run( - 0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR { - size_t token = task; - float* HWY_RESTRICT att_out = - activations_.attention.att_out.Row(token) + head * qkv_dim; - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = activations_.attention.q.Row(i) + - head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); + // Transposes K and V and converts to KVType. + template + void TransposeKAndV(const MatPtrT& qkv, const size_t num_tokens, + const size_t heads, const size_t qkv_dim, + MatPtrT& k_output, MatPtrT& v_output) { + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); + const size_t kNumTokensH = hwy::DivCeil(num_tokens, 2 * kNF); + const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF); + ParallelFor( + Parallelism::kWithinCluster, heads, env_.ctx, + /*cluster_idx=*/0, Callers::kFlashAttention, + [&](size_t head, size_t worker) { + const size_t qkv_offset = head * 3 * qkv_dim; + const size_t k_or_v_offset = head * 2 * kNF * kRoundedKVDim; + for (size_t token_h = 0; token_h < kNumTokensH; ++token_h) { + KVType* HWY_RESTRICT dst_k = k_output.Row(token_h); + KVType* HWY_RESTRICT dst_v = v_output.Row(token_h); + size_t dst_k_index = k_or_v_offset; + for (size_t q = 0; q < qkv_dim; q += 2) { + for (size_t token_l = 0; token_l < 2 * kNF; + ++token_l, dst_k_index += 2) { + const QKVType* HWY_RESTRICT src_k = + qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset + qkv_dim; + dst_k[dst_k_index] = hwy::ConvertScalarTo(src_k[q]); + dst_k[dst_k_index + 1] = + hwy::ConvertScalarTo(src_k[q + 1]); + } + } + size_t dst_v_index = k_or_v_offset; + for (size_t q = 0; q < qkv_dim; q += 2 * kNF) { + for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) { + const QKVType* HWY_RESTRICT src_v = + qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset + + qkv_dim * 2; + if (q + 2 * kNF <= qkv_dim) { + for (size_t q_l = 0; q_l < 2 * kNF; ++q_l) { + dst_v[dst_v_index++] = + hwy::ConvertScalarTo(src_v[q + q_l]); + } + } else { + for (size_t q_l = 0; q_l < qkv_dim - q; ++q_l) { + dst_v[dst_v_index++] = + hwy::ConvertScalarTo(src_v[q + q_l]); + } + } + } + } + // Zero out the padding area. + // In the loops above, the dst_k loop has written 2kNF x 2 + // consecutive elements for each q +=2, and the dst_v loop has + // written 2kNF x 2kNF consecutive elements for each q += 2 * kNF. + // Both of them therefore write 2kNF elements for each increment of + // q, so we can combine both into a single loop for the padding. + // This could be further simplified by writing a zero vector. + for (size_t q = qkv_dim; q < kRoundedKVDim; ++q) { + for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) { + dst_k[dst_k_index++] = hwy::ConvertScalarTo(0.0f); + dst_v[dst_v_index++] = hwy::ConvertScalarTo(0.0f); + } } - }); + } + }); + } + + // Computes the flash attention parameters. This is mostly about deciding on + // the tile sizes and filling the param structs with the correct offsets. + template + void ComputeParams(const uint32_t num_tokens, const size_t seq_len, + const size_t heads, const uint32_t qkv_dim, + const MatPtrT& q, const MatPtrT& k, + const MatPtrT& v, const MatPtrT& att_out, + std::vector& flash_params) { + flash_params.clear(); + for (uint32_t head = 0; head < heads; ++head) { + uint32_t token = 0; + while (token + k8xNFVTileSize <= num_tokens) { + flash_params.push_back(Tile148Params{ + .v_tile_size = k8xNFVTileSize, + .qi_index = token, + .kv_head = head, + }); + token += k8xNFVTileSize; + } + if (token + k4xNFVTileSize <= num_tokens) { + flash_params.push_back(Tile148Params{ + .v_tile_size = k4xNFVTileSize, + .qi_index = token, + .kv_head = head, + }); + token += k4xNFVTileSize; + } + while (token < num_tokens) { + flash_params.push_back(Tile148Params{ + .v_tile_size = 1, + .qi_index = token, + .kv_head = head, + }); + token += 1; + } + } + for (auto& param : flash_params) { + param.min_start_pos = 0; + param.max_last_pos = num_tokens - 1; + for (size_t i = 0; i < param.v_tile_size; ++i) { + param.q_offsets[i] = + q.Row(param.qi_index + i) + param.kv_head * qkv_dim - q.Row(0); + param.out_offsets[i] = att_out.Row(param.qi_index + i) + + param.kv_head * qkv_dim - att_out.Row(0); + param.start_pos[i] = 0; + param.last_pos[i] = num_tokens - 1; + } } } - HWY_NOINLINE void DotSoftmaxWeightedSum() { + // Runs the flash attention algorithm on Q, K, V. + HWY_NOINLINE void FlashAttention() { + GCPP_ZONE(env_.ctx, 0, Zones::kVitFlashAttentionInclusive); const size_t qkv_dim = layer_config_.qkv_dim; const size_t heads = layer_config_.heads; HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); - const size_t seq_len = - static_cast(activations_.attention.div_seq_len.GetDivisor()); + const size_t kNF = FloatsPerVector(); + const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF); + auto& attn = activations_.attention; + const size_t seq_len = static_cast(attn.div_seq_len.GetDivisor()); + if (attn.vit_K_T.Rows() >= seq_len) { + attn.vit_K_T.ReshapePackedRowsToCols(2 * kNF); + attn.vit_V_T.ReshapePackedRowsToCols(2 * kNF); + } const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); - PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); - - // Compute Q.K, softmax, and weighted V. - pool_.Run(0, layer_config_.heads * num_tokens_, caller1_, - [&](uint64_t task, size_t worker) HWY_ATTR { - const size_t head = task % layer_config_.heads; - const size_t token = task / layer_config_.heads; - // Compute Q.K scores, which are "logits" stored in head_att. - float* HWY_RESTRICT q = - activations_.attention.q.Row(token) + head * 3 * qkv_dim; - MulByConst(query_scale, q, qkv_dim); - float* HWY_RESTRICT head_att = - activations_.attention.att.Row(token) + head * seq_len; - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT k = activations_.attention.q.Row(i) + - head * 3 * qkv_dim + qkv_dim; - head_att[i] = Dot(q, k, qkv_dim); // score = q.k - } - // SoftMax yields "probabilities" in head_att. - Softmax(Logits(head_att, seq_len), env_.ctx, worker); - // Compute weighted sum of v into att_out. - float* HWY_RESTRICT att_out = - activations_.attention.att_out.Row(token) + head * qkv_dim; - hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = activations_.attention.q.Row(i) + - head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); - } - }); + ScaleQuery(attn.q, num_tokens_, heads, qkv_dim, query_scale, attn.q_bf); + TransposeKAndV(attn.q, num_tokens_, heads, qkv_dim, attn.vit_K_T, + attn.vit_V_T); + ComputeParams(num_tokens_, seq_len, heads, qkv_dim, attn.q_bf, attn.vit_K_T, + attn.vit_V_T, attn.att_out, attn.flash_params); + size_t num_tasks = attn.flash_params.size(); + + // For each param, compute fused flash Q.K, softmax and weighted V. + const auto func = [&, &ctx = env_.ctx](const size_t task, + size_t worker) HWY_ATTR { + GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention); + auto& param = attn.flash_params[task]; + MatPtrT kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), + kRoundedKVDim * 2 * kNF)); + kT.SetPtr(attn.vit_K_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF, + attn.vit_K_T.Stride()); + MatPtrT vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), + kRoundedKVDim * 2 * kNF)); + vT.SetPtr(attn.vit_V_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF, + attn.vit_V_T.Stride()); + DispatchDispatchTileFlashAttention148( + param, attn.q_bf, kT, vT, /*layer_idx=*/0, attn, attn.att_out, + qkv_dim, ctx, worker, /*attention_impl=*/AttentionImpl::kFlash); + }; + + { + PROFILER_ZONE("Gen.VitFlashAttention.ForkJoin"); + // Full parallelism is helpful, SmallParallelFor is insufficient. + HierarchicalParallelFor(num_tasks, env_.ctx, Callers::kFlashAttention, + func); + } } // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and // head_dim (`qkv_dim`) into output (`att_sums`). HWY_NOINLINE void SumHeads() { - PROFILER_ZONE("Gen.VitAttention.SumHeads"); auto* bias = layer_.vit.attn_out_b.PackedScale1(); // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] @@ -193,11 +283,7 @@ class VitAttention { HWY_INLINE void operator()() { ComputeQKV(); - if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) { - DotSoftmaxWeightedSumMatrix(); - } else { - DotSoftmaxWeightedSum(); - } + FlashAttention(); SumHeads(); } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index e035f1b9..c678d13f 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -669,10 +669,10 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem( size_t i = 0; while (i + NF * 2 <= size) { VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b; - out0a = hn::Load(df, out + i + out_offsets[0]); - out1a = hn::Load(df, out + i + out_offsets[1]); - out2a = hn::Load(df, out + i + out_offsets[2]); - out3a = hn::Load(df, out + i + out_offsets[3]); + out0a = hn::LoadU(df, out + i + out_offsets[0]); + out1a = hn::LoadU(df, out + i + out_offsets[1]); + out2a = hn::LoadU(df, out + i + out_offsets[2]); + out3a = hn::LoadU(df, out + i + out_offsets[3]); VF scale0 = hn::Set(df, scales[0]); VF scale1 = hn::Set(df, scales[1]); VF scale2 = hn::Set(df, scales[2]); @@ -681,28 +681,70 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem( out1a = hn::Mul(out1a, scale1); out2a = hn::Mul(out2a, scale2); out3a = hn::Mul(out3a, scale3); - out0b = hn::Load(df, out + i + NF + out_offsets[0]); - out1b = hn::Load(df, out + i + NF + out_offsets[1]); - out2b = hn::Load(df, out + i + NF + out_offsets[2]); - out3b = hn::Load(df, out + i + NF + out_offsets[3]); + out0b = hn::LoadU(df, out + i + NF + out_offsets[0]); + out1b = hn::LoadU(df, out + i + NF + out_offsets[1]); + out2b = hn::LoadU(df, out + i + NF + out_offsets[2]); + out3b = hn::LoadU(df, out + i + NF + out_offsets[3]); out0b = hn::Mul(out0b, scale0); out1b = hn::Mul(out1b, scale1); out2b = hn::Mul(out2b, scale2); out3b = hn::Mul(out3b, scale3); MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b); - hn::Store(out0a, df, out + i + out_offsets[0]); - hn::Store(out1a, df, out + i + out_offsets[1]); - hn::Store(out2a, df, out + i + out_offsets[2]); - hn::Store(out3a, df, out + i + out_offsets[3]); - hn::Store(out0b, df, out + i + NF + out_offsets[0]); - hn::Store(out1b, df, out + i + NF + out_offsets[1]); - hn::Store(out2b, df, out + i + NF + out_offsets[2]); - hn::Store(out3b, df, out + i + NF + out_offsets[3]); + hn::StoreU(out0a, df, out + i + out_offsets[0]); + hn::StoreU(out1a, df, out + i + out_offsets[1]); + hn::StoreU(out2a, df, out + i + out_offsets[2]); + hn::StoreU(out3a, df, out + i + out_offsets[3]); + hn::StoreU(out0b, df, out + i + NF + out_offsets[0]); + hn::StoreU(out1b, df, out + i + NF + out_offsets[1]); + hn::StoreU(out2b, df, out + i + NF + out_offsets[2]); + hn::StoreU(out3b, df, out + i + NF + out_offsets[3]); i += NF * 2; v_bf += 4 * NF * NF; } - HWY_DASSERT(size == i); + if (i < size) { + VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b; + out0a = hn::LoadN(df, out + i + out_offsets[0], size - i); + out1a = hn::LoadN(df, out + i + out_offsets[1], size - i); + out2a = hn::LoadN(df, out + i + out_offsets[2], size - i); + out3a = hn::LoadN(df, out + i + out_offsets[3], size - i); + VF scale0 = hn::Set(df, scales[0]); + VF scale1 = hn::Set(df, scales[1]); + VF scale2 = hn::Set(df, scales[2]); + VF scale3 = hn::Set(df, scales[3]); + out0a = hn::Mul(out0a, scale0); + out1a = hn::Mul(out1a, scale1); + out2a = hn::Mul(out2a, scale2); + out3a = hn::Mul(out3a, scale3); + if (i + NF < size) { + out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF); + out1b = hn::LoadN(df, out + i + NF + out_offsets[1], size - i - NF); + out2b = hn::LoadN(df, out + i + NF + out_offsets[2], size - i - NF); + out3b = hn::LoadN(df, out + i + NF + out_offsets[3], size - i - NF); + out0b = hn::Mul(out0b, scale0); + out1b = hn::Mul(out1b, scale1); + out2b = hn::Mul(out2b, scale2); + out3b = hn::Mul(out3b, scale3); + } else { + out0b = hn::Zero(df); + out1b = hn::Zero(df); + out2b = hn::Zero(df); + out3b = hn::Zero(df); + } + // Note that v_bf is always padded, so we can always load 2 * NF elements. + MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a, + out2a, out3a, out0b, out1b, out2b, out3b); + hn::StoreN(out0a, df, out + i + out_offsets[0], size - i); + hn::StoreN(out1a, df, out + i + out_offsets[1], size - i); + hn::StoreN(out2a, df, out + i + out_offsets[2], size - i); + hn::StoreN(out3a, df, out + i + out_offsets[3], size - i); + if (i + NF < size) { + hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF); + hn::StoreN(out1b, df, out + i + NF + out_offsets[1], size - i - NF); + hn::StoreN(out2b, df, out + i + NF + out_offsets[2], size - i - NF); + hn::StoreN(out3b, df, out + i + NF + out_offsets[3], size - i - NF); + } + } } template > @@ -743,26 +785,33 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem( size_t i = 0; while (i + NF * 2 <= size) { VF out0a, out0b; - out0a = hn::Load(df, out + i + out_offsets[0]); + out0a = hn::LoadU(df, out + i + out_offsets[0]); VF scale0 = hn::Set(df, scales[0]); out0a = hn::Mul(out0a, scale0); - out0b = hn::Load(df, out + i + NF + out_offsets[0]); + out0b = hn::LoadU(df, out + i + NF + out_offsets[0]); out0b = hn::Mul(out0b, scale0); MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b); - hn::Store(out0a, df, out + i + out_offsets[0]); - hn::Store(out0b, df, out + i + NF + out_offsets[0]); + hn::StoreU(out0a, df, out + i + out_offsets[0]); + hn::StoreU(out0b, df, out + i + NF + out_offsets[0]); i += NF * 2; v_bf += 4 * NF * NF; } - while (i < size) { - float sum = out[i + out_offsets[0]] * scales[0]; - const BF16* HWY_RESTRICT v_local = v_bf; - for (size_t lane = 0; lane < HWY_MIN(num_lanes, 2 * NF); - ++lane, v_local += 2 * NF) { - sum += hwy::ConvertScalarTo(*v_local) * c_mem[lane]; + if (i < size) { + VF out0a, out0b; + out0a = hn::LoadN(df, out + i + out_offsets[0], size - i); + VF scale0 = hn::Set(df, scales[0]); + out0a = hn::Mul(out0a, scale0); + if (i + NF < size) { + out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF); + out0b = hn::Mul(out0b, scale0); + } else { + out0b = hn::Zero(df); + } + MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b); + hn::StoreN(out0a, df, out + i + out_offsets[0], size - i); + if (i + NF < size) { + hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF); } - ++i; - ++v_bf; } } diff --git a/util/zones.cc b/util/zones.cc index 78a5fd8b..aec4bbd0 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -15,6 +15,8 @@ const char* ZoneName(Zones zone) { return "FlashAttention.FlashAttention"; case Zones::kFlashAttentionInclusive: return "FlashAttention.Inclusive"; + case Zones::kVitFlashAttentionInclusive: + return "Vit.FlashAttention.Inclusive"; case Zones::kFlashAttentionRmsNormAndPositionalEncoding: return "FlashAttention.RMSNormAndPositionalEncoding"; case Zones::kFlashAttentionTileFlashAttention1: @@ -106,6 +108,7 @@ const char* ZoneName(Zones zone) { hwy::ProfilerFlags ZoneFlags(Zones zone) { switch (zone) { case Zones::kFlashAttentionInclusive: + case Zones::kVitFlashAttentionInclusive: case Zones::kGenAttention: case Zones::kGenAttentionComputeQKV: case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive: diff --git a/util/zones.h b/util/zones.h index 6f1a68c3..64b859d2 100644 --- a/util/zones.h +++ b/util/zones.h @@ -13,6 +13,7 @@ namespace gcpp { enum class Zones { // Keep sorted kFlashAttentionFlashAttention, kFlashAttentionInclusive, + kVitFlashAttentionInclusive, kFlashAttentionRmsNormAndPositionalEncoding, kFlashAttentionTileFlashAttention1, kFlashAttentionTileFlashAttention4,