BUILD.bazel (2 additions, 0 deletions)

@@ -555,6 +555,7 @@ cc_library(
         ":ops",
         ":tensor_stats",
         ":threading_context",
+        "@highway//:abort_header_only",
     ],
 )
 
@@ -678,6 +679,7 @@ cc_library(
         ":attention",
         ":basics",
         ":configs",
+        ":flash_structs",
         ":gemma_args",
         ":kv_cache",
         ":mat",
gemma/activations.h (18 additions, 13 deletions)

@@ -76,8 +76,16 @@ struct AttentionActivations {
                     : batch_size * layer_config.heads,
                 allocator)),
        vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
-       vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
-       vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
+       vit_K_T(MatFactory(
+           "K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
+           layer_config.heads *
+               hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
+           allocator, MatPadding::kPacked)),
+       vit_V_T(MatFactory(
+           "V2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
+           layer_config.heads *
+               hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
+           allocator, MatPadding::kPacked)),
        pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
                                   config.model_dim, allocator)),
        // att is only valid for AttentionImpl::kOld.
@@ -126,7 +134,6 @@
     q.AllocateAndAttachRowPtrs(row_ptrs);
     q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
-    vit_C.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }
 
@@ -136,8 +143,7 @@
     // q_T rows are always qkv_dim!
 
     vit_Q.OverrideRows(batch_size);
-    // vit_K stays seq_len!
-    vit_C.OverrideRows(batch_size);
+    // vit_K_T and vit_V_T stay seq_len!
 
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
@@ -167,8 +173,8 @@
   MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.
 
   MatStorageT<float> vit_Q;
-  MatStorageT<float> vit_K;
-  MatStorageT<float> vit_C;
+  MatStorageT<KV_t> vit_K_T;
+  MatStorageT<KV_t> vit_V_T;
 
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;  // attention vector
@@ -214,8 +220,8 @@ struct AttentionActivationsPtrs {
     q_bf = activations.q_bf;
     q_T = activations.q_T;
     vit_Q = activations.vit_Q;
-    vit_K = activations.vit_K;
-    vit_C = activations.vit_C;
+    vit_K_T = activations.vit_K_T;
+    vit_V_T = activations.vit_V_T;
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
@@ -233,8 +239,7 @@
     // q_T rows are always qkv_dim!
 
     vit_Q.OverrideRows(batch_size);
-    // vit_K stays seq_len!
-    vit_C.OverrideRows(batch_size);
+    // vit_K_T and vit_V_T stay seq_len!
 
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
@@ -267,8 +272,8 @@
   MatPtrT<BF16> q_T;
 
   MatPtrT<float> vit_Q;
-  MatPtrT<float> vit_K;
-  MatPtrT<float> vit_C;
+  MatPtrT<KV_t> vit_K_T;
+  MatPtrT<KV_t> vit_V_T;
 
   // Output of RMSNorm before attention, size batch_size x model_dim.
   MatPtrT<float> pre_att_rms_out;
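Reviewer note (not part of the diff): the float vit_K / vit_C buffers are replaced by transposed KV_t storage (vit_K_T / vit_V_T) whose row count and per-head column count are rounded up to kMaxBF16PerVector, so each transposed row spans a whole number of BF16 vectors and the tiled kernel needs no remainder handling. Below is a minimal sketch of that rounding arithmetic; the 32-lane value (a 512-bit target) and the ViT shapes are illustrative assumptions, not numbers from this PR.

// Sketch only: mirrors the rounding in the vit_K_T / vit_V_T allocations above.
#include <cstddef>
#include <cstdio>

#include "hwy/base.h"  // hwy::RoundUpTo

int main() {
  constexpr size_t kMaxBF16PerVector = 32;  // assumed widest BF16 vector
  const size_t seq_len = 4097;              // deliberately not a multiple of 32
  const size_t qkv_dim = 72;                // also not a multiple of 32
  const size_t heads = 16;

  // Rows: padded sequence length; columns: heads * padded head dimension.
  const size_t rows = hwy::RoundUpTo(seq_len, kMaxBF16PerVector);           // 4128
  const size_t cols = heads * hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector);   // 16 * 96

  std::printf("K^T / V^T storage: %zu x %zu elements\n", rows, cols);
  return 0;
}

Because both dimensions are rounded up to the vector width, every row length is an integral number of vector loads even with MatPadding::kPacked, which appears to be the point of the change.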
gemma/flash_attention.cc (18 additions, 0 deletions)

@@ -2260,3 +2260,21 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
 HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace gcpp {
+HWY_EXPORT(DispatchTileFlashAttention148);
+
+void DispatchDispatchTileFlashAttention148(
+    Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
+    const MatPtrT<KV_t>& v, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
+    size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
+    AttentionImpl attention_impl) {
+  HWY_DYNAMIC_DISPATCH(DispatchTileFlashAttention148)(
+      params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker,
+      attention_impl);
+}
+
+}  // namespace gcpp
+#endif
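Reviewer note (not part of the diff): this is the standard Highway dynamic-dispatch idiom. flash_attention.cc is re-included once per SIMD target via foreach_target.h; the #if HWY_ONCE block is compiled only on the final pass, HWY_EXPORT builds the table of per-target implementations, and the wrapper (hence the doubled "Dispatch" in its name, since the exported per-target function is already called DispatchTileFlashAttention148) is the plain entry point that flash_attention.h declares for ordinary translation units. A self-contained sketch of the same pattern, with illustrative names (demo, ScaleInPlace) that are not from this PR:

// highway_dispatch_sketch.cc -- illustrative only; mirrors the structure above.
#include <cstddef>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "highway_dispatch_sketch.cc"  // re-included per target
#include "hwy/foreach_target.h"  // must precede highway.h
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace demo {
namespace HWY_NAMESPACE {

// Compiled once per SIMD target; HWY_NAMESPACE expands differently each time.
void ScaleInPlace(float* HWY_RESTRICT x, const size_t n, const float mul) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  const auto vmul = hn::Set(d, mul);
  size_t i = 0;
  for (; i + N <= n; i += N) {
    hn::Store(hn::Mul(hn::Load(d, x + i), vmul), d, x + i);
  }
  for (; i < n; ++i) x[i] *= mul;  // scalar remainder
}

}  // namespace HWY_NAMESPACE
}  // namespace demo
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace demo {
HWY_EXPORT(ScaleInPlace);  // table of per-target implementations

// Plain wrapper, analogous to DispatchDispatchTileFlashAttention148: callers
// that are not compiled per target link against this and get the best
// implementation selected at runtime.
void CallScaleInPlace(float* x, size_t n, float mul) {
  HWY_DYNAMIC_DISPATCH(ScaleInPlace)(x, n, mul);
}

}  // namespace demo
#endif  // HWY_ONCE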
gemma/flash_attention.h (7 additions, 8 deletions)

@@ -42,14 +42,6 @@ namespace gcpp {
       const MatPtr& query_norm_scale, size_t layer_idx, \
       const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
   \
-  void SingleFlashAttention(size_t start_pos, size_t last_pos, \
-                            const BF16* HWY_RESTRICT q, \
-                            const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-                            size_t layer_idx, \
-                            const AttentionActivationsPtrs& activations, \
-                            float* HWY_RESTRICT att_out, \
-                            ThreadingContext& ctx, size_t worker); \
-  \
   size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
                       size_t total_tasks, size_t target_parallelism); \
   \
@@ -83,6 +75,13 @@ HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
 
 #undef GEMMA_DECL_FLASH_ATTENTION
 
+void DispatchDispatchTileFlashAttention148(
+    Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
+    const MatPtrT<KV_t>& v, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
+    size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
+    AttentionImpl attention_impl);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
gemma/tiled_attention_test.cc (0 additions, 8 deletions)

@@ -544,8 +544,6 @@ void TestAttentionMultipleTokens() {
   test_env.SetupWeights();
   FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
   FillMatPtrT(test_env.activations->attention.q);
-  FillMatPtrT(test_env.activations->attention.vit_Q);
-  FillMatPtrT(test_env.activations->attention.vit_K);
   FillMatPtrT(test_env.activations->attention.att);
   FillMatPtrT(test_env.activations->attention.att_out);
   FillMatPtrT(test_env.activations->attention.softmax_max);
@@ -590,8 +588,6 @@ void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
   test_env.SetupWeights();
   FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
   FillMatPtrT(test_env.activations->attention.q);
-  FillMatPtrT(test_env.activations->attention.vit_Q);
-  FillMatPtrT(test_env.activations->attention.vit_K);
   FillMatPtrT(test_env.activations->attention.att);
   FillMatPtrT(test_env.activations->attention.att_out);
   FillMatPtrT(test_env.activations->attention.softmax_max);
@@ -763,8 +759,6 @@ void TestAttentionMultipleTokensBF16() {
   test_env.SetupWeights();
   FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
   FillMatPtrT(test_env.activations->attention.q);
-  FillMatPtrT(test_env.activations->attention.vit_Q);
-  FillMatPtrT(test_env.activations->attention.vit_K);
   FillMatPtrT(test_env.activations->attention.att);
   FillMatPtrT(test_env.activations->attention.att_out);
   FillMatPtrT(test_env.activations->attention.softmax_max);
@@ -807,8 +801,6 @@ void TestAttentionMultipleTokensInt8() {
   test_env.SetupWeights();
   FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
   FillMatPtrT(test_env.activations->attention.q);
-  FillMatPtrT(test_env.activations->attention.vit_Q);
-  FillMatPtrT(test_env.activations->attention.vit_K);
   FillMatPtrT(test_env.activations->attention.att);
   FillMatPtrT(test_env.activations->attention.att_out);
   FillMatPtrT(test_env.activations->attention.softmax_max);