@@ -424,6 +424,9 @@ struct whisper_context {
int64_t t_last ;
whisper_token tid_last ;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx ; // 0 - use default
} ;
// load the model from a ggml file
@@ -1076,13 +1079,11 @@ static bool whisper_encode(
const auto & mel_inp = wctx . mel ;
const auto & hparams = model . hparams ;
const int n_ctx = hparams. n_audio_ctx ;
const int n_ctx = wctx. exp_n_audio_ctx > 0 ? wctx . exp_n_audio_ctx : hparams. n_audio_ctx ;
const int n_state = hparams . n_audio_state ;
const int n_head = hparams . n_audio_head ;
const int n_layer = hparams . n_audio_layer ;
const int N = n_ctx ;
const int n_mels = hparams . n_mels ;
assert(mel_inp.n_mel == n_mels);
@@ -1132,7 +1133,30 @@ static bool whisper_encode(
cur = ggml_gelu ( ctx0 , cur ) ;
}
cur = ggml_add ( ctx0 , model . e_pe , ggml_transpose ( ctx0 , cur ) ) ;
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1;
//const int n_iter = 1500/n_ctx;
//iter = (iter + 1) % n_iter;
//if (iter == 0) {
// memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
// memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
//}
static int iter = 0 ;
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
cur = ggml_add ( ctx0 , e_pe , ggml_transpose ( ctx0 , cur ) ) ;
// ===================================================================
// original:
//cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
struct ggml_tensor * inpL = cur ;
@@ -1198,14 +1222,14 @@ static bool whisper_encode(
ggml_permute ( ctxL ,
ggml_cpy ( ctxL ,
Qcur ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , n_head , N ) ) ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , n_head , n_ctx ) ) ,
0 , 2 , 1 , 3 ) ;
struct ggml_tensor * K =
ggml_permute ( ctxL ,
ggml_cpy ( ctxL ,
Kcur ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , n_head , N ) ) ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , n_head , n_ctx ) ) ,
0 , 2 , 1 , 3 ) ;
struct ggml_tensor * V =
@@ -1213,9 +1237,9 @@ static bool whisper_encode(
ggml_permute ( ctxL ,
ggml_reshape_3d ( ctxL ,
Vcur ,
n_state / n_head , n_head , N ) ,
n_state / n_head , n_head , n_ctx ) ,
1 , 2 , 0 , 3 ) ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , N , n_state / n_head , n_head )
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_ctx , n_state / n_head , n_head )
) ;
struct ggml_tensor * KQV = ggml_flash_attn ( ctxL , Q , K , V , false ) ;
@@ -1224,14 +1248,14 @@ static bool whisper_encode(
ggml_permute ( ctxL ,
ggml_cpy ( ctxL ,
Qcur ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F32 , n_state / n_head , n_head , N ) ) ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F32 , n_state / n_head , n_head , n_ctx ) ) ,
0 , 2 , 1 , 3 ) ;
struct ggml_tensor * K =
ggml_permute ( ctxL ,
ggml_cpy ( ctxL ,
Kcur ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , n_head , N ) ) ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , n_head , n_ctx ) ) ,
0 , 2 , 1 , 3 ) ;
// K * Q
@@ -1249,7 +1273,7 @@ static bool whisper_encode(
// ggml_permute(ctxL,
// ggml_cpy(ctxL,
// Vcur,
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N )),
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx )),
// 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1259,9 +1283,9 @@ static bool whisper_encode(
ggml_permute ( ctxL ,
ggml_reshape_3d ( ctxL ,
Vcur ,
n_state / n_head , n_head , N ) ,
n_state / n_head , n_head , n_ctx ) ,
0 , 2 , 1 , 3 ) ,
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , N , n_head )
ggml_new_tensor_3d ( ctxL , GGML_TYPE_F16 , n_state / n_head , n_ctx , n_head )
) ;
struct ggml_tensor * KQV = ggml_mul_mat ( ctxL , ggml_transpose ( ctxL , V ) , KQ_soft_max ) ;
@@ -1271,7 +1295,7 @@ static bool whisper_encode(
cur = ggml_cpy ( ctxL ,
KQV_merged ,
ggml_new_tensor_2d ( ctxL , GGML_TYPE_F32 , n_state , N ) ) ;
ggml_new_tensor_2d ( ctxL , GGML_TYPE_F32 , n_state , n_ctx ) ) ;
}
// projection
@@ -1425,6 +1449,8 @@ static bool whisper_encode(
Vcross ) ,
Vcross ) ;
//struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
struct ggml_tensor * k = ggml_view_1d ( ctx0 , model . memory_cross_k , n_state * n_ctx , ( ggml_element_size ( model . memory_cross_k ) * n_state ) * ( il * n_ctx ) ) ;
struct ggml_tensor * v = ggml_view_1d ( ctx0 , model . memory_cross_v , n_state * n_ctx , ( ggml_element_size ( model . memory_cross_v ) * n_state ) * ( il * n_ctx ) ) ;
@@ -1474,7 +1500,7 @@ static bool whisper_decode(
const int n_layer = hparams . n_text_layer ;
const int N = n_tokens ;
const int M = hparams. n_audio_ctx ;
const int M = wctx. exp_n_audio_ctx > 0 ? wctx . exp_n_audio_ctx : hparams. n_audio_ctx ;
struct ggml_init_params params = {
. mem_size = wctx . buf_compute . size ( ) ,
@@ -2365,6 +2391,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.translate =*/ false ,
/*.no_context =*/ false ,
/*.single_segment =*/ false ,
/*.print_special_tokens =*/ false ,
/*.print_progress =*/ true ,
/*.print_realtime =*/ false ,
@@ -2374,8 +2401,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f ,
/*.thold_ptsum =*/ 0.01f ,
/*.max_len =*/ 0 ,
/*.max_tokens =*/ 0 ,
/*.speed_up =*/ false ,
/*.audio_ctx =*/ 0 ,
/*.language =*/ "en",
@@ -2405,6 +2434,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.translate =*/ false ,
/*.no_context =*/ false ,
/*.single_segment =*/ false ,
/*.print_special_tokens =*/ false ,
/*.print_progress =*/ true ,
/*.print_realtime =*/ false ,
@@ -2414,8 +2444,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f ,
/*.thold_ptsum =*/ 0.01f ,
/*.max_len =*/ 0 ,
/*.max_tokens =*/ 0 ,
/*.speed_up =*/ false ,
/*.audio_ctx =*/ 0 ,
/*.language =*/ "en",
@@ -2546,6 +2578,9 @@ int whisper_full(
prompt_past . clear ( ) ;
}
// overwrite audio_ctx
ctx - > exp_n_audio_ctx = params . audio_ctx ;
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if ( whisper_is_multilingual ( ctx ) ) {
@@ -2656,7 +2691,7 @@ int whisper_full(
//}
// end of text token
if (token.id == whisper_token_eot(ctx)) {
if (token.id == whisper_token_eot(ctx) || (params.max_tokens > 0 && i > params.max_tokens)) {
if ( result_len = = 0 ) {
if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1 ;
@@ -2665,6 +2700,12 @@ int whisper_full(
fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
}
}
if ( params . single_segment ) {
result_len = i + 1 ;
seek_delta = 100 * WHISPER_CHUNK_SIZE ;
}
break ;
}