whisper : do not use F16 tensors when in F32 mode (#369)

pull/379/head
Georgi Gerganov 2 years ago
parent b3c865083e
commit ad2a4ffa03
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -412,6 +412,8 @@ struct whisper_context {
std::vector<uint8_t> buf_compute; std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute_layer; std::vector<uint8_t> buf_compute_layer;
ggml_type wtype; // weight type (FP32 or FP16)
whisper_model model; whisper_model model;
whisper_vocab vocab; whisper_vocab vocab;
@ -629,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// for the big tensors, we have the option to store the data in 16-bit floats // for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation // in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
const ggml_type wtype = wctx.wtype;
size_t ctx_size = 0; size_t ctx_size = 0;
@ -650,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// encoder // encoder
{ {
// TODO: F16 .. maybe not?
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
@ -665,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// decoder // decoder
{ {
// TODO: F16 .. maybe not?
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
@ -982,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_mem = n_text_layer*n_text_ctx; const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
} }
// key/value memory for the cross-attention layer // key/value memory for the cross-attention layer
@ -993,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
const int n_mem = n_text_layer*n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
} }
const size_t memory_size = const size_t memory_size =
@ -1240,14 +1242,14 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Qcur, Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Kcur, Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * V = struct ggml_tensor * V =
@ -1257,7 +1259,7 @@ static bool whisper_encode(
Vcur, Vcur,
n_state/n_head, n_head, n_ctx), n_state/n_head, n_head, n_ctx),
1, 2, 0, 3), 1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
); );
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@ -1273,7 +1275,7 @@ static bool whisper_encode(
ggml_permute(ctxL, ggml_permute(ctxL,
ggml_cpy(ctxL, ggml_cpy(ctxL,
Kcur, Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
// K * Q // K * Q
@ -1291,7 +1293,7 @@ static bool whisper_encode(
// ggml_permute(ctxL, // ggml_permute(ctxL,
// ggml_cpy(ctxL, // ggml_cpy(ctxL,
// Vcur, // Vcur,
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3); // 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@ -1303,7 +1305,7 @@ static bool whisper_encode(
Vcur, Vcur,
n_state/n_head, n_head, n_ctx), n_state/n_head, n_head, n_ctx),
0, 2, 1, 3), 0, 2, 1, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
); );
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@ -1348,7 +1350,7 @@ static bool whisper_encode(
#ifdef USE_FLASH_FF #ifdef USE_FLASH_FF
cur = ggml_flash_ff(ctxL, cur = ggml_flash_ff(ctxL,
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)), ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else #else
// fully connected // fully connected
@ -3156,7 +3158,7 @@ int whisper_full_parallel(
// separate key + value memory for each processor // separate key + value memory for each processor
{ {
auto & ctx = model.ctx_mem; auto & mctx = model.ctx_mem;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@ -3169,8 +3171,8 @@ int whisper_full_parallel(
const int n_mem = n_text_layer*n_text_ctx; const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
} }
// key/value memory for the cross-attention layer // key/value memory for the cross-attention layer
@ -3180,8 +3182,8 @@ int whisper_full_parallel(
const int n_mem = n_text_layer*n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem; const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
} }
} }
} }

Loading…
Cancel
Save