From 1dcbe86a0c56b82a6014a04d045e911ea9a58179 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 31 Dec 2022 12:29:52 +0200 Subject: [PATCH] gpt-2 : experimenting with attention mask --- examples/gpt-2/main.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp index 333d93b..6507ec2 100644 --- a/examples/gpt-2/main.cpp +++ b/examples/gpt-2/main.cpp @@ -496,6 +496,7 @@ bool gpt2_eval( ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)) ); +#if 0 // KQ_masked = mask_past(KQ_scaled) // [n_past + N, N, 12] struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); @@ -503,6 +504,15 @@ bool gpt2_eval( // KQ = soft_max(KQ_masked) // [n_past + N, N, 12] struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); +#else + // KQ_masked = mask_past(KQ_scaled) + // [n_past + N, N, 12] + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + // [n_past + N, N, 12] + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); +#endif // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() // [n_past + N, 64, 12]