|
|
|
@ -252,12 +252,34 @@ static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
|
|
|
|
|
{ MODEL_LARGE, 9ull*MB },
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
|
|
|
|
|
{ MODEL_TINY, 74ull*MB },
|
|
|
|
|
{ MODEL_BASE, 142ull*MB },
|
|
|
|
|
{ MODEL_SMALL, 466ull*MB },
|
|
|
|
|
{ MODEL_MEDIUM, 1464ull*MB },
|
|
|
|
|
{ MODEL_LARGE, 2952ull*MB },
|
|
|
|
|
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
|
|
|
|
{ GGML_TYPE_F16,
|
|
|
|
|
{
|
|
|
|
|
{ MODEL_TINY, 74ull*MB },
|
|
|
|
|
{ MODEL_BASE, 142ull*MB },
|
|
|
|
|
{ MODEL_SMALL, 466ull*MB },
|
|
|
|
|
{ MODEL_MEDIUM, 1464ull*MB },
|
|
|
|
|
{ MODEL_LARGE, 2952ull*MB },
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{ GGML_TYPE_Q4_0,
|
|
|
|
|
{
|
|
|
|
|
{ MODEL_TINY, 26ull*MB },
|
|
|
|
|
{ MODEL_BASE, 50ull*MB },
|
|
|
|
|
{ MODEL_SMALL, 154ull*MB },
|
|
|
|
|
{ MODEL_MEDIUM, 470ull*MB },
|
|
|
|
|
{ MODEL_LARGE, 940ull*MB },
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{ GGML_TYPE_Q4_1,
|
|
|
|
|
{
|
|
|
|
|
{ MODEL_TINY, 31ull*MB },
|
|
|
|
|
{ MODEL_BASE, 57ull*MB },
|
|
|
|
|
{ MODEL_SMALL, 181ull*MB },
|
|
|
|
|
{ MODEL_MEDIUM, 559ull*MB },
|
|
|
|
|
{ MODEL_LARGE, 1122ull*MB },
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
|
|
|
@ -681,7 +703,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
|
|
|
|
|
const ggml_type wtype = cache.k->type;
|
|
|
|
|
WHISPER_ASSERT(wtype == cache.v->type);
|
|
|
|
|
|
|
|
|
|
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
|
|
|
|
|
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
|
|
|
|
|
|
|
|
|
|
struct ggml_init_params params;
|
|
|
|
|
params.mem_size = cache.buf.size();
|
|
|
|
@ -776,12 +798,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
model.type = e_model::MODEL_LARGE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// for the big tensors, we have the option to store the data in 16-bit floats
|
|
|
|
|
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
|
|
|
|
// in order to save memory and also to speed up the computation
|
|
|
|
|
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
|
|
|
|
wctx.wtype = GGML_TYPE_COUNT;
|
|
|
|
|
switch (model.hparams.f16) {
|
|
|
|
|
case 0: wctx.wtype = GGML_TYPE_F32; break;
|
|
|
|
|
case 1: wctx.wtype = GGML_TYPE_F16; break;
|
|
|
|
|
case 2: wctx.wtype = GGML_TYPE_Q4_0; break;
|
|
|
|
|
case 3: wctx.wtype = GGML_TYPE_Q4_1; break;
|
|
|
|
|
default:
|
|
|
|
|
{
|
|
|
|
|
fprintf(stderr, "%s: invalid model (bad f16 value %d)\n", __func__, model.hparams.f16);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const size_t scale = model.hparams.f16 ? 1 : 2;
|
|
|
|
|
|
|
|
|
|
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
|
|
|
|
|
|
|
|
|
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
|
|
|
|
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
|
|
|
|
|
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
|
|
|
|
@ -792,7 +827,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
|
|
|
|
|
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
|
|
|
|
|
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
|
|
|
|
|
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
|
|
|
|
|
fprintf(stderr, "%s: ftype = %s\n", __func__, ftype_str[model.hparams.f16]);
|
|
|
|
|
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
|
|
|
|
|
|
|
|
|
|
// print memory requirements
|
|
|
|
@ -803,7 +838,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
MEM_REQ_SCRATCH1.at (model.type) +
|
|
|
|
|
MEM_REQ_SCRATCH2.at (model.type) +
|
|
|
|
|
MEM_REQ_SCRATCH3.at (model.type) +
|
|
|
|
|
scale*MEM_REQ_MODEL.at (model.type) +
|
|
|
|
|
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
|
|
|
|
|
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
|
|
|
|
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
|
|
|
|
|
|
|
|
@ -819,9 +854,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
// always have at least one decoder
|
|
|
|
|
|
|
|
|
|
wctx.model.buf = new std::vector<uint8_t>();
|
|
|
|
|
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
|
|
|
|
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
|
|
|
|
|
|
|
|
|
|
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
|
|
|
|
|
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, GGML_TYPE_F16, model.hparams.n_text_ctx)) {
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -831,7 +866,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
|
|
|
|
|
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, GGML_TYPE_F16, model.hparams.n_audio_ctx)) {
|
|
|
|
|
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -963,92 +998,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
|
|
|
|
|
// encoder
|
|
|
|
|
{
|
|
|
|
|
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
|
|
|
|
|
ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
|
|
|
|
|
|
|
|
|
|
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
|
|
|
|
|
ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_1_w
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
|
|
|
|
|
|
|
|
|
|
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
|
|
|
|
|
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_2_w
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
|
|
|
|
|
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// decoder
|
|
|
|
|
{
|
|
|
|
|
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
|
|
|
|
|
ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
|
|
|
|
|
|
|
|
|
|
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
|
|
|
|
|
ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
|
|
|
|
|
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
|
|
|
|
|
ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
|
|
|
|
|
ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// encoder layers
|
|
|
|
|
{
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
|
|
|
|
|
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
|
|
|
|
|
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w
|
|
|
|
|
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
|
|
|
|
|
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
|
|
|
|
|
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
|
|
|
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// decoder layers
|
|
|
|
|
{
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
|
|
|
|
|
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
|
|
|
|
|
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w
|
|
|
|
|
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
|
|
|
|
|
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
|
|
|
|
|
//
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
|
|
|
|
|
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
|
|
|
|
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w
|
|
|
|
|
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
|
|
|
|
@ -1094,10 +1129,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
{
|
|
|
|
|
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
|
|
|
|
|
|
|
|
|
model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
|
|
|
|
|
model.e_conv_1_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_mels, n_audio_state);
|
|
|
|
|
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
|
|
|
|
|
|
|
|
|
model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
|
|
|
|
|
model.e_conv_2_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_audio_state, n_audio_state);
|
|
|
|
|
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
|
|
|
|
|
|
|
|
|
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
|
|
|
@ -1313,9 +1348,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
|
|
|
|
|
size_t bpe = 0;
|
|
|
|
|
|
|
|
|
|
switch (ftype) {
|
|
|
|
|
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
|
|
|
|
|
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
|
|
|
|
|
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
|
|
|
|
|
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
|
|
|
|
|
default:
|
|
|
|
|
{
|
|
|
|
|
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (nelements*bpe != ggml_nbytes(tensor)) {
|
|
|
|
|
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
|
|
|
|
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
|
|
|
|
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
|
|
|
|
return false;
|
|
|
|
@ -1513,14 +1560,14 @@ static bool whisper_encode(
|
|
|
|
|
ggml_permute(ctx0,
|
|
|
|
|
ggml_cpy(ctx0,
|
|
|
|
|
Qcur,
|
|
|
|
|
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * K =
|
|
|
|
|
ggml_permute(ctx0,
|
|
|
|
|
ggml_cpy(ctx0,
|
|
|
|
|
Kcur,
|
|
|
|
|
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * V =
|
|
|
|
@ -1530,7 +1577,7 @@ static bool whisper_encode(
|
|
|
|
|
Vcur,
|
|
|
|
|
n_state/n_head, n_head, n_ctx),
|
|
|
|
|
1, 2, 0, 3),
|
|
|
|
|
ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
|
|
|
|
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
|
|
|
|
@ -1546,7 +1593,7 @@ static bool whisper_encode(
|
|
|
|
|
ggml_permute(ctx0,
|
|
|
|
|
ggml_cpy(ctx0,
|
|
|
|
|
Kcur,
|
|
|
|
|
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
0, 2, 1, 3);
|
|
|
|
|
|
|
|
|
|
// K * Q
|
|
|
|
@ -1564,7 +1611,7 @@ static bool whisper_encode(
|
|
|
|
|
// ggml_permute(ctx0,
|
|
|
|
|
// ggml_cpy(ctx0,
|
|
|
|
|
// Vcur,
|
|
|
|
|
// ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
|
|
|
|
|
// 1, 2, 0, 3);
|
|
|
|
|
|
|
|
|
|
//struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
|
|
|
@ -1576,7 +1623,7 @@ static bool whisper_encode(
|
|
|
|
|
Vcur,
|
|
|
|
|
n_state/n_head, n_head, n_ctx),
|
|
|
|
|
0, 2, 1, 3),
|
|
|
|
|
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
|
|
|
|
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
|
|
|
|
@ -1634,7 +1681,7 @@ static bool whisper_encode(
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|
|
|
|
|
|
cur = ggml_flash_ff(ctx0,
|
|
|
|
|
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
|
|
|
|
|
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_state, n_ctx)),
|
|
|
|
|
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
|
|
|
#else
|
|
|
|
|
wctx.use_buf(ctx0, 0);
|
|
|
|
|