diff --git a/.clang-format b/.clang-format
new file mode 100755
index 0000000..f543efa
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,91 @@
+# Options are listed here:
+# https://clang.llvm.org/docs/ClangFormatStyleOptions.html
+---
+AccessModifierOffset: -3
+AlignAfterOpenBracket: BlockIndent
+AlignArrayOfStructures: Left
+AlignConsecutiveAssignments: true
+AlignConsecutiveDeclarations: false
+AlignConsecutiveMacros: true
+AlignEscapedNewlines: Right
+AlignOperands: true
+AlignTrailingComments: true
+AllowAllArgumentsOnNextLine: false
+AllowAllConstructorInitializersOnNextLine: false
+AllowAllParametersOfDeclarationOnNextLine: false
+AllowShortBlocksOnASingleLine: false
+AllowShortCaseLabelsOnASingleLine: false
+AllowShortFunctionsOnASingleLine: None
+AllowShortLambdasOnASingleLine: All
+AllowShortIfStatementsOnASingleLine: Never
+AllowShortLoopsOnASingleLine: false
+AlwaysBreakAfterDefinitionReturnType: None
+AlwaysBreakAfterReturnType: None
+AlwaysBreakBeforeMultilineStrings: false
+AlwaysBreakTemplateDeclarations: MultiLine
+BinPackArguments: false
+BinPackParameters: false
+BitFieldColonSpacing: Both
+BreakBeforeBinaryOperators: None
+BreakBeforeBraces: Attach
+BreakBeforeConceptDeclarations: Always
+BreakBeforeInheritanceComma: false
+BreakInheritanceList: BeforeColon
+BreakBeforeTernaryOperators: true
+BreakConstructorInitializersBeforeComma: true
+BreakConstructorInitializers: AfterColon
+BreakStringLiterals: true
+ColumnLimit: 120
+CompactNamespaces: false
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+Cpp11BracedListStyle: false
+DeriveLineEnding: false
+DerivePointerAlignment: false
+EmptyLineBeforeAccessModifier: LogicalBlock
+EmptyLineAfterAccessModifier: Never
+FixNamespaceComments: false
+IncludeBlocks: Preserve
+IncludeIsMainRegex: '(Test)?$'
+IndentCaseLabels: true
+IndentPPDirectives: AfterHash
+IndentWidth: 4
+IndentWrappedFunctionNames: true
+InsertBraces: true
+KeepEmptyLinesAtTheStartOfBlocks: true
+Language: Cpp
+MaxEmptyLinesToKeep: 1
+NamespaceIndentation: All
+PackConstructorInitializers: Never
+PenaltyBreakAssignment: 2
+PenaltyBreakBeforeFirstCallParameter: 19
+PenaltyBreakComment: 300
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakString: 1000
+PenaltyExcessCharacter: 1000000
+PenaltyReturnTypeOnItsOwnLine: 1000
+PointerAlignment: Middle
+ReflowComments: true
+SortIncludes: true
+SortUsingDeclarations: true
+SpaceAfterCStyleCast: true
+SpaceAfterLogicalNot: false
+SpaceAfterTemplateKeyword: true
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeCaseColon: false
+SpaceBeforeCpp11BracedList: false
+SpaceBeforeCtorInitializerColon: true
+SpaceBeforeInheritanceColon: true
+SpaceBeforeParens: ControlStatements
+SpaceBeforeRangeBasedForLoopColon: true
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 1
+SpacesInAngles: false
+SpacesInContainerLiterals: true
+SpacesInCStyleCastParentheses: false
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
+Standard: c++11
+TabWidth: 4
+UseCRLF: false
+UseTab: Never
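A note on the configuration above: the options that most visibly shape the rest of this diff are `ColumnLimit: 120`, `IndentWidth: 4`, `AlignAfterOpenBracket: BlockIndent`, `BinPackArguments`/`BinPackParameters: false`, and `PointerAlignment: Middle`. The snippet below is purely illustrative (a made-up function, not from either file) of the layout those options produce:

```cpp
// Illustrative only: a hypothetical declaration formatted the way this
// configuration formats it. Once a signature exceeds 120 columns,
// BinPackParameters: false puts one parameter per line, BlockIndent dedents
// the closing ")", and PointerAlignment: Middle writes "T * p".
static size_t count_positive_entries(
    const long * values,
    size_t n_values,
    long threshold
) {
    size_t count = 0;
    for (size_t i = 0; i < n_values; ++i) {
        if (values[i] > threshold) { // InsertBraces: true forces braces here
            ++count;
        }
    }
    return count;
}
```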
diff --git a/ggml.c b/ggml.c
index ddcdea5..1774de8 100644
--- a/ggml.c
+++ b/ggml.c
@@ -173,12 +173,12 @@ static inline float fp32_from_bits(uint32_t w) {
 }
 
 static inline uint32_t fp32_to_bits(float f) {
-	union {
-		float as_value;
-		uint32_t as_bits;
-	} fp32;
-	fp32.as_value = f;
-	return fp32.as_bits;
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
 }
 
 float ggml_fp16_to_fp32(ggml_fp16_t h) {
diff --git a/ggml.h b/ggml.h
index a217d2d..1d915da 100644
--- a/ggml.h
+++ b/ggml.h
@@ -169,13 +169,13 @@
 //
 //
 
-#ifdef __cplusplus
+#ifdef __cplusplus
 extern "C" {
 #endif
 
-#include <stdint.h>
-#include <stddef.h>
 #include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
 
 #define GGML_MAX_DIMS 4
 #define GGML_MAX_NODES 4096
@@ -191,7 +191,7 @@ typedef uint16_t ggml_fp16_t;
 #endif
 
 // convert FP16 <-> FP32
-float       ggml_fp16_to_fp32(ggml_fp16_t x);
+float ggml_fp16_to_fp32(ggml_fp16_t x);
 ggml_fp16_t ggml_fp32_to_fp16(float x);
 
 struct ggml_object;
@@ -253,8 +253,8 @@ enum ggml_op {
 struct ggml_tensor {
     enum ggml_type type;
 
-    int    n_dims;
-    int    ne[GGML_MAX_DIMS]; // number of elements
+    int n_dims;
+    int ne[GGML_MAX_DIMS];    // number of elements
     size_t nb[GGML_MAX_DIMS]; // stride in bytes:
                               // nb[0] = sizeof(type)
                               // nb[1] = nb[0] * ne[0] + padding
@@ -274,7 +274,7 @@ struct ggml_tensor {
     int n_tasks;
 
     // performance
-    int     perf_runs;
+    int perf_runs;
     int64_t perf_cycles;
     int64_t perf_time_us;
 
@@ -296,7 +296,7 @@ struct ggml_cgraph {
     struct ggml_tensor * leafs[GGML_MAX_NODES];
 
     // performance
-    int     perf_runs;
+    int perf_runs;
     int64_t perf_cycles;
     int64_t perf_time_us;
 };
@@ -307,19 +307,19 @@ struct ggml_init_params {
     void * mem_buffer; // if NULL, memory will be allocated internally
 };
 
-void    ggml_time_init(void); // call this once at the beginning of the program
+void ggml_time_init(void); // call this once at the beginning of the program
 int64_t ggml_time_ms(void);
 int64_t ggml_time_us(void);
 int64_t ggml_cycles(void);
 int64_t ggml_cycles_per_ms(void);
 
-void ggml_print_object (const struct ggml_object * obj);
+void ggml_print_object(const struct ggml_object * obj);
 void ggml_print_objects(const struct ggml_context * ctx);
 
-int    ggml_nelements(const struct ggml_tensor * tensor);
-size_t ggml_nbytes   (const struct ggml_tensor * tensor);
+int ggml_nelements(const struct ggml_tensor * tensor);
+size_t ggml_nbytes(const struct ggml_tensor * tensor);
 
-size_t ggml_type_size   (enum ggml_type type);
+size_t ggml_type_size(enum ggml_type type);
 size_t ggml_element_size(const struct ggml_tensor * tensor);
 
 struct ggml_context * ggml_init(struct ggml_init_params params);
@@ -327,290 +327,192 @@ void ggml_free(struct ggml_context * ctx);
 
 size_t ggml_used_mem(const struct ggml_context * ctx);
 
-struct ggml_tensor * ggml_new_tensor(
-        struct ggml_context * ctx,
-        enum ggml_type type,
-        int n_dims,
-        const int *ne);
-
-struct ggml_tensor * ggml_new_tensor_1d(
-        struct ggml_context * ctx,
-        enum ggml_type type,
-        int ne0);
-
-struct ggml_tensor * ggml_new_tensor_2d(
-        struct ggml_context * ctx,
-        enum ggml_type type,
-        int ne0,
-        int ne1);
-
-struct ggml_tensor * ggml_new_tensor_3d(
-        struct ggml_context * ctx,
-        enum ggml_type type,
-        int ne0,
-        int ne1,
-        int ne2);
+struct ggml_tensor * ggml_new_tensor(struct ggml_context * ctx, enum ggml_type type, int n_dims, const int * ne);
+
+struct ggml_tensor * ggml_new_tensor_1d(struct ggml_context * ctx, enum ggml_type type, int ne0);
+
+struct ggml_tensor * ggml_new_tensor_2d(struct ggml_context * ctx, enum ggml_type type, int ne0, int ne1);
+
+struct ggml_tensor * ggml_new_tensor_3d(struct ggml_context * ctx, enum ggml_type type, int ne0, int ne1, int ne2);
 
 struct ggml_tensor * ggml_new_tensor_4d(
-        struct ggml_context * ctx,
-        enum ggml_type type,
-        int ne0,
-        int ne1,
-        int ne2,
-        int ne3);
+    struct ggml_context * ctx,
+    enum ggml_type type,
+    int ne0,
+    int ne1,
+    int ne2,
+    int ne3
+);
 
 struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
 
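The declarations above are enough to sketch the basic ggml workflow: every tensor lives inside a fixed-size arena owned by a `ggml_context`, so nothing touches the allocator after `ggml_init()`. A minimal sketch, assuming ggml is built and linked; the 16 MB arena size is an arbitrary value for the demo:

```cpp
#include <cstdio>

#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024, // arbitrary arena size for the demo
        /*.mem_buffer =*/ nullptr,          // NULL -> ggml allocates the arena itself
    };
    struct ggml_context * ctx = ggml_init(params);

    // 2D FP32 tensor: ne0 = 4 elements per row, ne1 = 3 rows
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    ggml_set_f32(a, 1.5f); // fill every element

    printf("nelements = %d, nbytes = %zu\n", ggml_nelements(a), ggml_nbytes(a));

    ggml_free(ctx); // frees every tensor in the arena at once
    return 0;
}
```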
-struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
 struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
-struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
-struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+struct ggml_tensor * ggml_set_i32(struct ggml_tensor * tensor, int32_t value);
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value);
 
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
-void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
 
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
-void  ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
 
- void * ggml_get_data    (const struct ggml_tensor * tensor);
+void * ggml_get_data(const struct ggml_tensor * tensor);
 float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
 //
 // operations on tensors with backpropagation
 //
 
-struct ggml_tensor * ggml_dup(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_dup(struct ggml_context * ctx, struct ggml_tensor * a);
 
-struct ggml_tensor * ggml_add(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_add(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
-struct ggml_tensor * ggml_sub(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_sub(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
-struct ggml_tensor * ggml_mul(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_mul(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
-struct ggml_tensor * ggml_div(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_div(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
-struct ggml_tensor * ggml_sqr(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_sqr(struct ggml_context * ctx, struct ggml_tensor * a);
 
-struct ggml_tensor * ggml_sqrt(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_sqrt(struct ggml_context * ctx, struct ggml_tensor * a);
 
 // return scalar
 // TODO: compute sum along rows
-struct ggml_tensor * ggml_sum(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_sum(struct ggml_context * ctx, struct ggml_tensor * a);
 
 // mean along rows
-struct ggml_tensor * ggml_mean(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_mean(struct ggml_context * ctx, struct ggml_tensor * a);
 
 // if a is the same shape as b, and a is not parameter, return a
 // otherwise, return a new tensor: repeat(a) to fit in b
-struct ggml_tensor * ggml_repeat(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_repeat(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
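`ggml_repeat` is ggml's broadcasting primitive, and the `ggml_add(ctx, ggml_repeat(ctx, bias, x), x)` idiom built from it appears throughout the whisper.cpp changes below. A short sketch (the helper name is hypothetical):

```cpp
// Broadcast-add sketch: 'bias' (one row) is repeated to the shape of 'x'
// before the element-wise add -- the same pattern whisper.cpp uses for its
// conv, attention, and MLP biases further down in this diff.
static struct ggml_tensor * add_bias(
    struct ggml_context * ctx,
    struct ggml_tensor * x,   // [n_state, n_ctx]
    struct ggml_tensor * bias // [n_state, 1]
) {
    return ggml_add(ctx, ggml_repeat(ctx, bias, x), x);
}
```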
 
-struct ggml_tensor * ggml_abs(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_abs(struct ggml_context * ctx, struct ggml_tensor * a);
 
-struct ggml_tensor * ggml_sgn(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_sgn(struct ggml_context * ctx, struct ggml_tensor * a);
 
-struct ggml_tensor * ggml_neg(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_neg(struct ggml_context * ctx, struct ggml_tensor * a);
 
-struct ggml_tensor * ggml_step(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_step(struct ggml_context * ctx, struct ggml_tensor * a);
 
-struct ggml_tensor * ggml_relu(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_relu(struct ggml_context * ctx, struct ggml_tensor * a);
 
 // TODO: double-check this computation is correct
-struct ggml_tensor * ggml_gelu(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_gelu(struct ggml_context * ctx, struct ggml_tensor * a);
 
 // normalize along rows
 // TODO: eps is hardcoded to 1e-5 for now
-struct ggml_tensor * ggml_norm(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_norm(struct ggml_context * ctx, struct ggml_tensor * a);
 
 // A: m rows, n columns
 // B: p rows, n columns (i.e. we transpose it internally)
 // result is m columns, p rows
-struct ggml_tensor * ggml_mul_mat(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_mul_mat(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
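The shape comment above `ggml_mul_mat` is easier to read with concrete numbers. In ggml, `ne[0]` is the fastest-running (column) dimension, so both operands share `ne[0]` and the result has shape `{ a->ne[1], b->ne[1] }`. A sketch:

```cpp
// Shape sketch for ggml_mul_mat: A is m x n (ne = { n, m }), B is p x n
// (ne = { n, p }, transposed internally), result is { m, p }.
static struct ggml_tensor * mul_mat_shapes(struct ggml_context * ctx) {
    struct ggml_tensor * A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2); // n = 3, m = 2
    struct ggml_tensor * B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 4); // n = 3, p = 4
    return ggml_mul_mat(ctx, A, B); // result ne = { 2, 4 }
}
```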
 
 //
 // operations on tensors without backpropagation
 //
 
 // in-place, returns view(a)
-struct ggml_tensor * ggml_scale(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_scale(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
 // a -> b, return view(b)
-struct ggml_tensor * ggml_cpy(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_cpy(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
 // return view(a), b specifies the new shape
 // TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_reshape(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
 // return view(a)
 // TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape_2d(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int ne0,
-        int ne1);
+struct ggml_tensor * ggml_reshape_2d(struct ggml_context * ctx, struct ggml_tensor * a, int ne0, int ne1);
 
 // return view(a)
 // TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape_3d(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int ne0,
-        int ne1,
-        int ne2);
+struct ggml_tensor * ggml_reshape_3d(struct ggml_context * ctx, struct ggml_tensor * a, int ne0, int ne1, int ne2);
 
 // offset in bytes
-struct ggml_tensor * ggml_view_1d(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int ne0,
-        size_t offset);
+struct ggml_tensor * ggml_view_1d(struct ggml_context * ctx, struct ggml_tensor * a, int ne0, size_t offset);
 
 struct ggml_tensor * ggml_view_2d(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int ne0,
-        int ne1,
-        size_t nb1, // row stride in bytes
-        size_t offset);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int ne0,
+    int ne1,
+    size_t nb1, // row stride in bytes
+    size_t offset
+);
 
 struct ggml_tensor * ggml_permute(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int axis0,
-        int axis1,
-        int axis2,
-        int axis3);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int axis0,
+    int axis1,
+    int axis2,
+    int axis3
+);
 
 // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
-struct ggml_tensor * ggml_transpose(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_transpose(struct ggml_context * ctx, struct ggml_tensor * a);
 
-struct ggml_tensor * ggml_get_rows(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_get_rows(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
 // set elements above the diagonal to -INF
 // in-place, returns view(a)
-struct ggml_tensor * ggml_diag_mask_inf(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int n_past);
+struct ggml_tensor * ggml_diag_mask_inf(struct ggml_context * ctx, struct ggml_tensor * a, int n_past);
 
 // in-place, returns view(a)
-struct ggml_tensor * ggml_soft_max(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a);
+struct ggml_tensor * ggml_soft_max(struct ggml_context * ctx, struct ggml_tensor * a);
 
 // rotary position embedding
 // in-place, returns view(a)
 // if mode == 1, skip n_past elements
 // TODO: avoid creating a new tensor every time
-struct ggml_tensor * ggml_rope(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int n_past,
-        int n_dims,
-        int mode);
+struct ggml_tensor * ggml_rope(struct ggml_context * ctx, struct ggml_tensor * a, int n_past, int n_dims, int mode);
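Note that `ggml_view_1d`/`ggml_view_2d` above take their offset and row stride in bytes, not elements, so callers scale everything by `ggml_element_size()`; that is exactly what whisper.cpp does below for its `e_pe` window and KV-cache slots. A hedged sketch (the helper name is hypothetical):

```cpp
// Sketch: view rows [row0, row0 + n_rows) of a 2D tensor without copying.
// Views share the parent's data; nb1 (row stride) and offset are in BYTES.
static struct ggml_tensor * view_rows(
    struct ggml_context * ctx,
    struct ggml_tensor * t, // 2D, ne[0] elements per row
    int row0,
    int n_rows
) {
    const size_t row_bytes = t->ne[0] * ggml_element_size(t);
    return ggml_view_2d(ctx, t, t->ne[0], n_rows, row_bytes, row0 * row_bytes);
}
```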
 
 // padding = 1
 // TODO: we don't support extra parameters for now
 //       that's why we are hard-coding the stride, padding, and dilation
 //       not great ..
-struct ggml_tensor * ggml_conv_1d_1s(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_conv_1d_1s(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
-struct ggml_tensor * ggml_conv_1d_2s(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b);
+struct ggml_tensor * ggml_conv_1d_2s(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b);
 
 struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor * q,
-        struct ggml_tensor * k,
-        struct ggml_tensor * v,
-        bool masked);
+    struct ggml_context * ctx,
+    struct ggml_tensor * q,
+    struct ggml_tensor * k,
+    struct ggml_tensor * v,
+    bool masked
+);
 
 struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b0,
-        struct ggml_tensor * b1,
-        struct ggml_tensor * c0,
-        struct ggml_tensor * c1);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b0,
+    struct ggml_tensor * b1,
+    struct ggml_tensor * c0,
+    struct ggml_tensor * c1
+);
 
 //
 // automatic differentiation
 //
 
-void ggml_set_param(
-        struct ggml_context * ctx,
-        struct ggml_tensor * tensor);
+void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
 
 void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
-struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
+struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor);
 struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
 void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-void ggml_graph_reset  (struct ggml_cgraph * cgraph);
+void ggml_graph_reset(struct ggml_cgraph * cgraph);
 
 // print info and performance information for the graph
 void ggml_graph_print(const struct ggml_cgraph * cgraph);
@@ -699,8 +601,8 @@ struct ggml_opt_params {
             int n_iter;
             int max_linesearch;
 
-            float eps; // convergence tolerance
-            float ftol; // line search tolerance
+            float eps;  // convergence tolerance
+            float ftol; // line search tolerance
             float wolfe;
             float min_step;
             float max_step;
@@ -712,10 +614,7 @@ struct ggml_opt_params {
 
 struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
 
 // optimize the function defined by the tensor f
-enum ggml_opt_result ggml_opt(
-        struct ggml_context * ctx,
-        struct ggml_opt_params params,
-        struct ggml_tensor * f);
+enum ggml_opt_result ggml_opt(struct ggml_context * ctx, struct ggml_opt_params params, struct ggml_tensor * f);
 
 //
 // system info
@@ -732,6 +631,6 @@ int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);
 
-#ifdef __cplusplus
+#ifdef __cplusplus
 }
 #endif
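With the graph API above, the header can be exercised end to end: ops only record nodes, and nothing is evaluated until `ggml_graph_compute` runs. A minimal sketch, assuming ggml is linked; all functions used are declared in this header:

```cpp
#include <cstdio>

#include "ggml.h"

// f = sum(a*a + b); with the values below, expect 8 * (2*2 + 1) = 40.
int main() {
    struct ggml_init_params params = { /*.mem_size =*/ 16 * 1024 * 1024, /*.mem_buffer =*/ nullptr };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(a, 2.0f);
    ggml_set_f32(b, 1.0f);

    struct ggml_tensor * f = ggml_sum(ctx, ggml_add(ctx, ggml_sqr(ctx, a), b));

    struct ggml_cgraph gf = ggml_build_forward(f); // record the graph ending at f
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);                  // evaluate it

    printf("f = %f\n", ggml_get_f32_1d(f, 0));

    ggml_free(ctx);
    return 0;
}
```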
diff --git a/whisper.cpp b/whisper.cpp
index 84c2490..bc7d72d 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -11,13 +11,13 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
-#include 
 
 #define USE_FLASH_ATTN
-//#define USE_FLASH_FF
+// #define USE_FLASH_FF
 
 // available whisper models
 enum e_model {
@@ -29,6 +29,7 @@ enum e_model {
     MODEL_LARGE,
 };
 
+// clang-format off
 static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "en",  { 0,  "english",    } },
     { "zh",  { 1,  "chinese",    } },
@@ -130,9 +131,11 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "jw",  { 97, "javanese",   } },
     { "su",  { 98, "sundanese",  } },
 };
+// clang-format on
 
-static const size_t MB = 1024*1024;
+static const size_t MB = 1024 * 1024;
 
+// clang-format off
 static const std::map<e_model, size_t> MEM_REQ_MODEL = {
     { MODEL_TINY,    74ull*MB },
     { MODEL_BASE,   142ull*MB },
@@ -180,6 +183,7 @@ static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
     { MODEL_MEDIUM,  84ull*MB },
     { MODEL_LARGE,  110ull*MB },
 };
+// clang-format on
 
 struct whisper_mel {
     int n_len;
@@ -408,9 +412,9 @@ struct whisper_context {
     int64_t t_start_us = 0;
 
     std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
-    std::vector<uint8_t>   buf_memory;
-    std::vector<uint8_t>   buf_compute;
-    std::vector<uint8_t>   buf_compute_layer;
+    std::vector<uint8_t> buf_memory;
+    std::vector<uint8_t> buf_compute;
+    std::vector<uint8_t> buf_compute_layer;
 
     whisper_model model;
     whisper_vocab vocab;
@@ -434,10 +438,8 @@ struct whisper_context {
     int32_t exp_n_audio_ctx; // 0 - use default
 };
 
-template<typename T>
-static void read_safe(std::ifstream& fin, T& dest)
-{
-    fin.read((char*)& dest, sizeof(T));
+template <typename T> static void read_safe(std::ifstream & fin, T & dest) {
+    fin.read((char *) &dest, sizeof(T));
 }
 
 // load the model from a ggml file
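`read_safe` replaces the loader's repeated `fin.read((char *) &x, sizeof(x))` calls with a single sizeof-driven template. A standalone usage sketch; the header fields here are illustrative, not the actual whisper file format:

```cpp
#include <cstdint>
#include <fstream>

template <typename T> static void read_safe(std::ifstream & fin, T & dest) {
    fin.read((char *) &dest, sizeof(T));
}

// Each call consumes exactly sizeof(T) bytes from the stream, so a packed
// binary header can be read field by field.
static bool read_header(std::ifstream & fin) {
    int32_t magic = 0;
    int32_t n_vocab = 0;
    read_safe(fin, magic);
    read_safe(fin, n_vocab);
    return fin.good();
}
```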
@@ -473,7 +475,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         }
     }
 
-    //load hparams
+    // load hparams
     {
         auto & hparams = model.hparams;
 
@@ -528,7 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
         wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
         wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
-        wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+        wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))
+        );
     }
 
     // load mel filters
@@ -547,11 +550,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         int32_t n_vocab = 0;
         read_safe(fin, n_vocab);
 
-        //if (n_vocab != model.hparams.n_vocab) {
-        //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
-        //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
-        //    return false;
-        //}
+        // if (n_vocab != model.hparams.n_vocab) {
+        //     fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+        //             __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+        //     return false;
+        // }
 
         std::string word;
         std::vector<char> tmp;
@@ -568,14 +571,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
                //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                word = "";
            }
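The block above is a dry-run size estimate: every tensor contributes elements x `ggml_type_size()`, plus a fixed 256-byte per-object overhead at the end, so the whole arena can be allocated once up front. A sketch of the same arithmetic with made-up layer counts:

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative re-statement of the sizing pattern above: payload bytes are
// (elements) x (bytes per element); each graph object adds ~256 bytes.
static size_t estimate_ctx_size(int n_layer, int n_state, size_t type_size) {
    size_t ctx_size = 0;
    ctx_size += (size_t) n_layer * (n_state * type_size);           // one vector per layer
    ctx_size += (size_t) n_layer * (n_state * n_state * type_size); // one square matrix per layer
    ctx_size += (size_t) (15 + 24 * n_layer) * 256;                 // object overhead
    return ctx_size;
}

int main() {
    printf("~%.2f MB\n", estimate_ctx_size(6, 512, 2) / (1024.0 * 1024.0));
    return 0;
}
```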
@@ -784,10 +784,10 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         {
             model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
-            model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype,         3, n_mels, n_audio_state);
+            model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
             model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
 
-            model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype,         3, n_audio_state, n_audio_state);
+            model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
             model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
 
             model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@@ -811,24 +811,24 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
             layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
-            layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
-            layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
+            layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4 * n_audio_state);
+            layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4 * n_audio_state);
 
-            layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
-            layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+            layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4 * n_audio_state, n_audio_state);
+            layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
             layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
             layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
-            layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype,         n_audio_state, n_audio_state);
+            layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
             layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
-            layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype,         n_audio_state, n_audio_state);
+            layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
 
-            layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype,         n_audio_state, n_audio_state);
+            layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
             layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
-            layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype,         n_audio_state, n_audio_state);
+            layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
             layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
             // map by name
@@ -880,38 +880,38 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
             layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
-            layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
-            layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
+            layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4 * n_text_state);
+            layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4 * n_text_state);
 
-            layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
-            layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
+            layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4 * n_text_state, n_text_state);
+            layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
             layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
             layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
-            layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
             layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
-            layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
 
-            layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
             layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
-            layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
             layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
             layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
             layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
-            layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
             layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
-            layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
 
-            layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
             layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
-            layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype,         n_text_state, n_text_state);
+            layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
             layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
 
             // map by name
@@ -938,19 +938,23 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
                 model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] =
+                    layer.cross_attn_ln_0_w;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] =
+                    layer.cross_attn_q_w;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
 
                 model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] =
+                    layer.cross_attn_v_w;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
 
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] =
+                    layer.cross_attn_ln_1_w;
+                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
             }
         }
     }
@@ -980,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // key/value memory for the self-attention layer
         {
-            const int n_mem      = n_text_layer*n_text_ctx;
-            const int n_elements = n_text_state*n_mem;
+            const int n_mem      = n_text_layer * n_text_ctx;
+            const int n_elements = n_text_state * n_mem;
 
             model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
             model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
@@ -991,18 +995,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         {
             const int n_audio_ctx = hparams.n_audio_ctx;
 
-            const int n_mem      = n_text_layer*n_audio_ctx;
-            const int n_elements = n_text_state*n_mem;
+            const int n_mem      = n_text_layer * n_audio_ctx;
+            const int n_elements = n_text_state * n_mem;
 
             model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
             model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
         }
 
-        const size_t memory_size =
-            ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
-            ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
+                                   ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
 
-        fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
+        fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
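`memory_k`/`memory_v` above are the decoder's self-attention KV cache, and the cross variants cache the encoder projections; all four are F16, so each cache costs `2 * n_text_layer * ctx * n_text_state * sizeof(ggml_fp16_t)` bytes in total. A sketch with illustrative dimensions (not read from any model):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative values only -- the real ones come from the model hparams.
    const int n_text_layer = 4;
    const int n_text_ctx   = 448;
    const int n_text_state = 384;

    const int64_t n_mem      = (int64_t) n_text_layer * n_text_ctx;
    const int64_t n_elements = (int64_t) n_text_state * n_mem;

    // keys + values, 2 bytes per F16 element
    const double mb = 2.0 * n_elements * 2 / (1024.0 * 1024.0);
    printf("self-attention KV cache: %.2f MB\n", mb);
    return 0;
}
```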
+ // "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); total_size += ggml_nbytes(tensor); model.n_loaded++; } - fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size / 1024.0 / 1024.0); if (model.n_loaded == 0) { - fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + fprintf( + stderr, + "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", + __func__ + ); } else if (model.n_loaded != (int) model.tensors.size()) { - fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + fprintf( + stderr, + "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", + __func__, + model.tensors.size(), + model.n_loaded + ); return false; } } @@ -1092,10 +1122,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx // - n_threads: number of threads to use // - mel_offset: offset in the mel spectrogram (i.e. audio offset) // -static bool whisper_encode( - whisper_context & wctx, - const int n_threads, - const int mel_offset) { +static bool whisper_encode(whisper_context & wctx, const int n_threads, const int mel_offset) { const auto & model = wctx.model; const auto & mel_inp = wctx.mel; const auto & hparams = model.hparams; @@ -1114,18 +1141,18 @@ static bool whisper_encode( struct ggml_context * ctx0 = ggml_init(params); - struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels); assert(mel->type == GGML_TYPE_F32); { float * dst = (float *) mel->data; memset(dst, 0, ggml_nbytes(mel)); const int i0 = std::min(mel_offset, mel_inp.n_len); - const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2 * n_ctx, mel_inp.n_len); for (int j = 0; j < mel_inp.n_mel; ++j) { for (int i = i0; i < i1; ++i) { - dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + dst[j * 2 * n_ctx + (i - i0)] = mel_inp.data[j * mel_inp.n_len + i]; } } } @@ -1135,40 +1162,32 @@ static bool whisper_encode( // convolution + gelu { cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_1_b, - cur), - cur); + cur = ggml_add(ctx0, ggml_repeat(ctx0, model.e_conv_1_b, cur), cur); cur = ggml_gelu(ctx0, cur); cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_2_b, - cur), - cur); + cur = ggml_add(ctx0, ggml_repeat(ctx0, model.e_conv_2_b, cur), cur); cur = ggml_gelu(ctx0, cur); } // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) - //static int iter = -1; - //const int n_iter = 1500/n_ctx; + // static int iter = -1; + // const int n_iter = 1500/n_ctx; - //iter = (iter + 1) % n_iter; + // iter = (iter + 1) % n_iter; - //if (iter == 0) { - // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); - // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); - //} + // if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + // } static int iter = 0; - const size_t e_pe_stride = 
model.e_pe->ne[0]*ggml_element_size(model.e_pe); - const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + const size_t e_pe_stride = model.e_pe->ne[0] * ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0] * ggml_element_size(model.e_pe) * n_ctx * iter; struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); @@ -1176,7 +1195,7 @@ static bool whisper_encode( // =================================================================== // original: - //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + // cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); struct ggml_tensor * inpL = cur; @@ -1196,136 +1215,113 @@ static bool whisper_encode( cur = ggml_norm(ctxL, inpL); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + cur = ggml_add( + ctxL, + ggml_mul(ctxL, ggml_repeat(ctxL, layer.attn_ln_0_w, cur), cur), + ggml_repeat(ctxL, layer.attn_ln_0_b, cur) + ); } // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, - layer.attn_q_w, - cur); + struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, layer.attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, - layer.attn_q_b, - Qcur), - Qcur); + Qcur = ggml_add(ctxL, ggml_repeat(ctxL, layer.attn_q_b, Qcur), Qcur); - //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + // Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, - layer.attn_k_w, - cur); + struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, layer.attn_k_w, cur); - //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + // Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); - struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, - layer.attn_v_w, - cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, layer.attn_v_w, cur); - Vcur = ggml_add(ctxL, - ggml_repeat(ctxL, - layer.attn_v_b, - Vcur), - Vcur); + Vcur = ggml_add(ctxL, ggml_repeat(ctxL, layer.attn_v_b, Vcur), Vcur); // ------ #ifdef USE_FLASH_ATTN - struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, - Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_cpy(ctxL, - Kcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * V = - ggml_cpy(ctxL, - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) - ); + struct ggml_tensor * Q = ggml_permute( + ctxL, + ggml_cpy(ctxL, Qcur, ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state / n_head, n_head, n_ctx)), + 0, + 2, + 1, + 3 + ); + + struct ggml_tensor * K = ggml_permute( + ctxL, + ggml_cpy(ctxL, Kcur, ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state / n_head, n_head, n_ctx)), + 0, + 2, + 1, + 3 + ); + + struct ggml_tensor * V = ggml_cpy( + ctxL, + ggml_permute(ctxL, ggml_reshape_3d(ctxL, Vcur, n_state / n_head, n_head, n_ctx), 1, 2, 0, 3), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state / n_head, n_head) + ); struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); #else - struct ggml_tensor * Q = - ggml_permute(ctxL, - 
ggml_cpy(ctxL, - Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_cpy(ctxL, - Kcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); + struct ggml_tensor * Q = ggml_permute( + ctxL, + ggml_cpy(ctxL, Qcur, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state / n_head, n_head, n_ctx)), + 0, + 2, + 1, + 3 + ); + + struct ggml_tensor * K = ggml_permute( + ctxL, + ggml_cpy(ctxL, Kcur, ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state / n_head, n_head, n_ctx)), + 0, + 2, + 1, + 3 + ); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); struct ggml_tensor * KQ_scaled = - ggml_scale(ctxL, - KQ, - ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) - ); + ggml_scale(ctxL, KQ, ggml_new_f32(ctxL, 1.0f / sqrt(float(n_state) / n_head))); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled); - //struct ggml_tensor * V_trans = - // ggml_permute(ctxL, - // ggml_cpy(ctxL, - // Vcur, - // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), - // 1, 2, 0, 3); - - //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); - - struct ggml_tensor * V = - ggml_cpy(ctxL, - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - Vcur, - n_state/n_head, n_head, n_ctx), - 0, 2, 1, 3), - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) - ); + // struct ggml_tensor * V_trans = + // ggml_permute(ctxL, + // ggml_cpy(ctxL, + // Vcur, + // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), + // 1, 2, 0, 3); + + // struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); + + struct ggml_tensor * V = ggml_cpy( + ctxL, + ggml_permute(ctxL, ggml_reshape_3d(ctxL, Vcur, n_state / n_head, n_head, n_ctx), 0, 2, 1, 3), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state / n_head, n_ctx, n_head) + ); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); #endif struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); - cur = ggml_cpy(ctxL, - KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx)); + cur = ggml_cpy(ctxL, KQV_merged, ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx)); } // projection { - cur = ggml_mul_mat(ctxL, - layer.attn_ln_1_w, - cur); + cur = ggml_mul_mat(ctxL, layer.attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.attn_ln_1_b, cur), - cur); + cur = ggml_add(ctxL, ggml_repeat(ctxL, layer.attn_ln_1_b, cur), cur); } // add the input @@ -1340,38 +1336,35 @@ static bool whisper_encode( cur = ggml_norm(ctxL, inpFF); // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + cur = ggml_add( + ctxL, + ggml_mul(ctxL, ggml_repeat(ctxL, layer.mlp_ln_w, cur), cur), + ggml_repeat(ctxL, layer.mlp_ln_b, cur) + ); } #ifdef USE_FLASH_FF - cur = ggml_flash_ff(ctxL, - ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)), - layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); + cur = ggml_flash_ff( + ctxL, + ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)), + layer.mlp_0_w, + layer.mlp_0_b, + layer.mlp_1_w, + layer.mlp_1_b + ); #else // fully connected - cur = ggml_mul_mat(ctxL, - layer.mlp_0_w, - cur); + cur = ggml_mul_mat(ctxL, layer.mlp_0_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_0_b, cur), - cur); + cur = ggml_add(ctxL, 
ggml_repeat(ctxL, layer.mlp_0_b, cur), cur); // GELU activation cur = ggml_gelu(ctxL, cur); // projection - cur = ggml_mul_mat(ctxL, - layer.mlp_1_w, - cur); + cur = ggml_mul_mat(ctxL, layer.mlp_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_1_b, cur), - cur); + cur = ggml_add(ctxL, ggml_repeat(ctxL, layer.mlp_1_b, cur), cur); #endif } @@ -1380,22 +1373,22 @@ static bool whisper_encode( { struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; + gf.n_threads = n_threads; ggml_build_forward_expand(&gf, inpO); - ggml_graph_compute (ctxL, &gf); + ggml_graph_compute(ctxL, &gf); - //ggml_graph_print(&gf); + // ggml_graph_print(&gf); } // TODO: this is a hack to have per-layer computation graphs - need to come up with something better // input for next layer (inpO -> inpL) memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); - inpL->op = GGML_OP_NONE; + inpL->op = GGML_OP_NONE; inpL->src0 = nullptr; inpL->src1 = nullptr; - //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); + // printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); ggml_free(ctxL); } @@ -1407,22 +1400,22 @@ static bool whisper_encode( cur = ggml_norm(ctx0, cur); // cur = ln_f_g*cur + ln_f_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.e_ln_w, cur), - cur), - ggml_repeat(ctx0, model.e_ln_b, cur)); + cur = ggml_add( + ctx0, + ggml_mul(ctx0, ggml_repeat(ctx0, model.e_ln_w, cur), cur), + ggml_repeat(ctx0, model.e_ln_b, cur) + ); } // run the computation { struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; + gf.n_threads = n_threads; ggml_build_forward_expand(&gf, cur); - ggml_graph_compute (ctx0, &gf); + ggml_graph_compute(ctx0, &gf); - //ggml_graph_print(&gf); + // ggml_graph_print(&gf); } // cur @@ -1442,36 +1435,40 @@ static bool whisper_encode( // pre-compute cross-attention memory { struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; + gf.n_threads = n_threads; // TODO: hack to disconnect the encoded features from the previous graph - cur->op = GGML_OP_NONE; + cur->op = GGML_OP_NONE; cur->src0 = nullptr; cur->src1 = nullptr; for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; - struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, - layer.cross_attn_k_w, - cur); + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, cur); - Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25))); - struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, - layer.cross_attn_v_w, - cur); + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, cur); - Vcross = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_v_b, - Vcross), - Vcross); + Vcross = ggml_add(ctx0, ggml_repeat(ctx0, layer.cross_attn_v_b, Vcross), Vcross); - //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, 
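For orientation: the self-attention block above is standard scaled dot-product attention, `softmax(K^T Q / sqrt(n_state / n_head)) V` evaluated per head, with `USE_FLASH_ATTN` fusing the same math into a single `ggml_flash_attn` node. A scalar sketch of what one query row computes (illustrative only; ggml evaluates this as fused tensor ops):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// One query against n_ctx keys/values of width d:
// out = sum_i softmax(q . k_i / sqrt(d)) * v_i
static std::vector<float> attend_one_query(
    const std::vector<float> & q,
    const std::vector<std::vector<float>> & k,
    const std::vector<std::vector<float>> & v
) {
    const size_t d = q.size(), n = k.size();
    std::vector<float> s(n);
    float smax = -1e30f;
    for (size_t i = 0; i < n; ++i) {
        float dot = 0.0f;
        for (size_t j = 0; j < d; ++j) {
            dot += q[j] * k[i][j];
        }
        s[i] = dot / std::sqrt((float) d);
        smax = std::max(smax, s[i]);
    }
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        s[i] = std::exp(s[i] - smax); // subtract the max for numerical stability
        sum += s[i];
    }
    std::vector<float> out(d, 0.0f);
    for (size_t i = 0; i < n; ++i) {
        for (size_t j = 0; j < d; ++j) {
            out[j] += (s[i] / sum) * v[i][j];
        }
    }
    return out;
}
```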
(ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx)); + // struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, + // (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); struct + // ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, + // (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor * k = ggml_view_1d( + ctx0, + model.memory_cross_k, + n_state * n_ctx, + (ggml_element_size(model.memory_cross_k) * n_state) * (il * n_ctx) + ); + struct ggml_tensor * v = ggml_view_1d( + ctx0, + model.memory_cross_v, + n_state * n_ctx, + (ggml_element_size(model.memory_cross_v) * n_state) * (il * n_ctx) + ); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); @@ -1482,7 +1479,7 @@ static bool whisper_encode( //////////////////////////////////////////////////////////////////////////// - //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0); + // printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0); ggml_free(ctx0); @@ -1500,11 +1497,12 @@ static bool whisper_encode( // - n_past: number of past tokens to prefix the prompt with // static bool whisper_decode( - whisper_context & wctx, - const int n_threads, - const whisper_token * tokens, - const int n_tokens, - const int n_past) { + whisper_context & wctx, + const int n_threads, + const whisper_token * tokens, + const int n_tokens, + const int n_past +) { const auto & model = wctx.model; const auto & hparams = model.hparams; @@ -1528,7 +1526,7 @@ static bool whisper_decode( struct ggml_context * ctx0 = ggml_init(params); struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); + memcpy(embd->data, tokens, N * ggml_element_size(embd)); struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); for (int i = 0; i < N; ++i) { @@ -1537,9 +1535,7 @@ static bool whisper_decode( // token encoding + position encoding struct ggml_tensor * cur = - ggml_add(ctx0, - ggml_get_rows(ctx0, model.d_te, embd), - ggml_get_rows(ctx0, model.d_pe, position)); + ggml_add(ctx0, ggml_get_rows(ctx0, model.d_te, embd), ggml_get_rows(ctx0, model.d_pe, position)); struct ggml_tensor * inpL = cur; @@ -1551,56 +1547,52 @@ static bool whisper_decode( paramsL.mem_buffer = wctx.buf_compute_layer.data(); struct ggml_context * ctxL = ggml_init(paramsL); - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; // norm { cur = ggml_norm(ctxL, inpL); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctxL, layer.attn_ln_0_b, cur)); + cur = ggml_add( + ctxL, + ggml_mul(ctxL, ggml_repeat(ctxL, layer.attn_ln_0_w, cur), cur), + ggml_repeat(ctxL, layer.attn_ln_0_b, cur) + ); } // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, - layer.attn_q_w, - cur); + struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, layer.attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, - layer.attn_q_b, - Qcur), - Qcur); + Qcur = ggml_add(ctxL, ggml_repeat(ctxL, layer.attn_q_b, Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state) / n_head, -0.25))); // note: no bias for Key - struct 
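/*
 * Layout note for the cross-attention memory filled above: memory_cross_k
 * and memory_cross_v are flat 1-D tensors holding one [n_ctx, n_state] slab
 * per decoder layer, so the slab for layer il starts at element
 * il * n_ctx * n_state. The ggml_view_1d() calls compute exactly that:
 *
 *   size_t elem  = ggml_element_size(model.memory_cross_k);   // F16
 *   size_t off   = elem * n_state * (size_t) il * n_ctx;      // byte offset of layer il
 *   size_t count = (size_t) n_state * n_ctx;                  // elements in one slab
 */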
ggml_tensor * Kcur = ggml_mul_mat(ctxL, - layer.attn_k_w, - cur); + struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, layer.attn_k_w, cur); - Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state) / n_head, -0.25))); - struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, - layer.attn_v_w, - cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, layer.attn_v_w, cur); - Vcur = ggml_add(ctxL, - ggml_repeat(ctxL, - layer.attn_v_b, - Vcur), - Vcur); + Vcur = ggml_add(ctxL, ggml_repeat(ctxL, layer.attn_v_b, Vcur), Vcur); // store key and value to memory { - struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d( + ctxL, + model.memory_k, + N * n_state, + (ggml_element_size(model.memory_k) * n_state) * (il * n_ctx + n_past) + ); + struct ggml_tensor * v = ggml_view_1d( + ctxL, + model.memory_v, + N * n_state, + (ggml_element_size(model.memory_v) * n_state) * (il * n_ctx + n_past) + ); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v)); @@ -1608,57 +1600,79 @@ static bool whisper_decode( // ------ - struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, - Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state), - n_state/n_head, n_head, n_past + N), - 0, 2, 1, 3); + struct ggml_tensor * Q = ggml_permute( + ctxL, + ggml_cpy(ctxL, Qcur, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state / n_head, n_head, N)), + 0, + 2, + 1, + 3 + ); + + struct ggml_tensor * K = ggml_permute( + ctxL, + ggml_reshape_3d( + ctxL, + ggml_view_1d( + ctxL, + model.memory_k, + (n_past + N) * n_state, + il * n_ctx * ggml_element_size(model.memory_k) * n_state + ), + n_state / n_head, + n_head, + n_past + N + ), + 0, + 2, + 1, + 3 + ); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, - // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) - // ); + // struct ggml_tensor * KQ_scaled = + // ggml_scale(ctxL, + // KQ, + // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked); - struct ggml_tensor * V_trans = - ggml_permute(ctxL, - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state), - n_state/n_head, n_head, n_past + N), - 1, 2, 0, 3); + struct ggml_tensor * V_trans = ggml_permute( + ctxL, + ggml_reshape_3d( + ctxL, + ggml_view_1d( + ctxL, + model.memory_v, + (n_past + N) * n_state, + il * n_ctx * ggml_element_size(model.memory_v) * n_state + ), + n_state / n_head, + n_head, + n_past + N + ), + 1, + 2, + 0, + 3 + ); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); - cur = ggml_cpy(ctxL, - KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + cur = ggml_cpy(ctxL, KQV_merged, 
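/*
 * KV-cache sketch for the self-attention above: each decoder layer owns
 * n_ctx rows of n_state elements inside memory_k / memory_v, and the N rows
 * computed by this call are written starting at row n_past of layer il:
 *
 *   size_t row0 = (size_t) il * n_ctx + n_past;   // first destination row
 *   // elements [row0 * n_state, (row0 + N) * n_state) are overwritten,
 *   // and the attention then reads the first n_past + N rows of layer il.
 */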
ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); } { - cur = ggml_mul_mat(ctxL, - layer.attn_ln_1_w, - cur); + cur = ggml_mul_mat(ctxL, layer.attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.attn_ln_1_b, cur), - cur); + cur = ggml_add(ctxL, ggml_repeat(ctxL, layer.attn_ln_1_b, cur), cur); } // add the input @@ -1669,60 +1683,72 @@ static bool whisper_decode( cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur), - cur), - ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur)); + cur = ggml_add( + ctxL, + ggml_mul(ctxL, ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur), cur), + ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur) + ); } // cross-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, - layer.cross_attn_q_w, - cur); + struct ggml_tensor * Qcur = ggml_mul_mat(ctxL, layer.cross_attn_q_w, cur); - Qcur = ggml_add(ctxL, - ggml_repeat(ctxL, - layer.cross_attn_q_b, - Qcur), - Qcur); + Qcur = ggml_add(ctxL, ggml_repeat(ctxL, layer.cross_attn_q_b, Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state) / n_head, -0.25))); // Kcross is already scaled - struct ggml_tensor * Kcross = - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state), - n_state/n_head, n_head, M); - - struct ggml_tensor * Vcross = - ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state), - n_state/n_head, n_head, M); + struct ggml_tensor * Kcross = ggml_reshape_3d( + ctxL, + ggml_view_1d( + ctxL, + model.memory_cross_k, + M * n_state, + il * M * ggml_element_size(model.memory_cross_k) * n_state + ), + n_state / n_head, + n_head, + M + ); + + struct ggml_tensor * Vcross = ggml_reshape_3d( + ctxL, + ggml_view_1d( + ctxL, + model.memory_cross_v, + M * n_state, + il * M * ggml_element_size(model.memory_cross_v) * n_state + ), + n_state / n_head, + n_head, + M + ); // ------ - struct ggml_tensor * Q = - ggml_permute(ctxL, - ggml_cpy(ctxL, - Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), - 0, 2, 1, 3); + struct ggml_tensor * Q = ggml_permute( + ctxL, + ggml_cpy(ctxL, Qcur, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state / n_head, n_head, N)), + 0, + 2, + 1, + 3 + ); struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, - // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) - // ); + // struct ggml_tensor * KQ_scaled = + // ggml_scale(ctxL, + // KQ, + // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + // ); // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past); + // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ); @@ -1733,20 +1759,14 @@ static bool whisper_decode( struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); // cur = KQV_merged.contiguous().view(n_state, N) - cur = ggml_cpy(ctxL, - KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + cur = ggml_cpy(ctxL, KQV_merged, ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); } // projection { - cur = 
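/*
 * Scaling note for the two decoder attention paths above: Q and K are each
 * pre-multiplied by pow(n_state/n_head, -0.25), Q at compute time and K when
 * the cache is filled, so their product already carries the usual 1/sqrt(d)
 * attention scale. That is why the explicit KQ_scaled step is left
 * commented out here, unlike in the encoder.
 */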
ggml_mul_mat(ctxL, - layer.cross_attn_ln_1_w, - cur); + cur = ggml_mul_mat(ctxL, layer.cross_attn_ln_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur), - cur); + cur = ggml_add(ctxL, ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur), cur); } // add the input @@ -1761,33 +1781,25 @@ static bool whisper_decode( cur = ggml_norm(ctxL, inpFF); // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctxL, - ggml_mul(ctxL, - ggml_repeat(ctxL, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctxL, layer.mlp_ln_b, cur)); + cur = ggml_add( + ctxL, + ggml_mul(ctxL, ggml_repeat(ctxL, layer.mlp_ln_w, cur), cur), + ggml_repeat(ctxL, layer.mlp_ln_b, cur) + ); } // fully connected - cur = ggml_mul_mat(ctxL, - layer.mlp_0_w, - cur); + cur = ggml_mul_mat(ctxL, layer.mlp_0_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_0_b, cur), - cur); + cur = ggml_add(ctxL, ggml_repeat(ctxL, layer.mlp_0_b, cur), cur); // GELU activation cur = ggml_gelu(ctxL, cur); // projection - cur = ggml_mul_mat(ctxL, - layer.mlp_1_w, - cur); + cur = ggml_mul_mat(ctxL, layer.mlp_1_w, cur); - cur = ggml_add(ctxL, - ggml_repeat(ctxL, layer.mlp_1_b, cur), - cur); + cur = ggml_add(ctxL, ggml_repeat(ctxL, layer.mlp_1_b, cur), cur); } // output from this layer @@ -1795,20 +1807,20 @@ static bool whisper_decode( { ggml_build_forward_expand(&gf, inpO); - ggml_graph_compute (ctxL, &gf); + ggml_graph_compute(ctxL, &gf); - //ggml_graph_print(&gf); + // ggml_graph_print(&gf); } // TODO: this is a hack to have per-layer computation graphs - need to come up with something better // input for next layer (inpO -> inpL) memcpy(inpL->data, inpO->data, ggml_nbytes(inpL)); - inpL->op = GGML_OP_NONE; + inpL->op = GGML_OP_NONE; inpL->src0 = nullptr; inpL->src1 = nullptr; if (N > 1) { - //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); + // printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0); } ggml_free(ctxL); @@ -1820,11 +1832,11 @@ static bool whisper_decode( { cur = ggml_norm(ctx0, cur); - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.d_ln_w, cur), - cur), - ggml_repeat(ctx0, model.d_ln_b, cur)); + cur = ggml_add( + ctx0, + ggml_mul(ctx0, ggml_repeat(ctx0, model.d_ln_w, cur), cur), + ggml_repeat(ctx0, model.d_ln_b, cur) + ); } struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); @@ -1836,22 +1848,22 @@ static bool whisper_decode( // run the computation { struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; + gf.n_threads = n_threads; ggml_build_forward_expand(&gf, cur); - ggml_graph_compute (ctx0, &gf); + ggml_graph_compute(ctx0, &gf); } - logits_out.resize(N*n_vocab); - memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + logits_out.resize(N * n_vocab); + memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float) * N * n_vocab); - probs_out.resize(N*n_vocab); - memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab); + probs_out.resize(N * n_vocab); + memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float) * N * n_vocab); if (N > 1) { - //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; - //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); - //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx); + // const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; + // printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, 
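/*
 * Note on the logits computed above: the decoder reuses the token embedding
 * matrix d_te as the output projection (weight tying), i.e.
 * logits = d_te * cur, so no separate unembedding matrix is stored.
 */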
mem_per_token); + // printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx); } ggml_free(ctx0); @@ -1861,10 +1873,11 @@ static bool whisper_decode( // the most basic sampling scheme - select the top token static whisper_token_data whisper_sample_best( - whisper_vocab & vocab, - const float * probs, - bool force_timestamp, - bool is_initial) { + whisper_vocab & vocab, + const float * probs, + bool force_timestamp, + bool is_initial +) { whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; @@ -1879,7 +1892,7 @@ static whisper_token_data whisper_sample_best( } { - double sum_ts = 0.0; + double sum_ts = 0.0; double max_ts = -1.0; double max_tx = -1.0; @@ -1891,31 +1904,33 @@ static whisper_token_data whisper_sample_best( const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits; // the initial timestamp cannot be larger than 100 - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + // ref: + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 if (is_initial) { - for (int i = i0; i < n_logits; ++ i) { + for (int i = i0; i < n_logits; ++i) { probs_id[i].first = -INFINITY; } } for (int i = vocab.token_beg; i < i1; i++) { sum_ts += probs_id[i].first; - if (probs_id[i].first > max_ts) { - max_ts = probs_id[i].first; + if (probs_id[i].first > max_ts) { + max_ts = probs_id[i].first; result.tid = probs_id[i].second; } } - // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a - // timestamp token + // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample + // a timestamp token if (sum_ts > max_tx || force_timestamp) { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 + // ref: + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 for (int i = 0; i < vocab.token_beg; i++) { probs_id[i].first = -INFINITY; } } - result.pt = max_ts/(sum_ts + 1e-10); + result.pt = max_ts / (sum_ts + 1e-10); result.ptsum = sum_ts; } @@ -1923,24 +1938,26 @@ static whisper_token_data whisper_sample_best( const int top_k = 4; std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); + probs_id.begin(), + probs_id.begin() + top_k, + probs_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + } + ); probs_id.resize(top_k); - //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); - //} + // printf("\n"); + // for (int i = 0; i < (int) probs_id.size(); i++) { + // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, + // probs_id[i].second); + // } int res = 0; - while ((probs_id[res].second == vocab.token_sot || - probs_id[res].second == vocab.token_solm || + while ((probs_id[res].second == vocab.token_sot || probs_id[res].second == vocab.token_solm || probs_id[res].second == vocab.token_not) && - res < (int) probs_id.size() - 1) { + res < (int) probs_id.size() - 1) { res++; } @@ -1954,12 +1971,12 @@ static whisper_token_data whisper_sample_best( // 6000 -> 01:00.000 static std::string 
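/*
 * Timestamp rule implemented above (after openai/whisper decoding.py): if
 * the total probability mass of the timestamp tokens exceeds the best
 * text-token probability, text tokens are masked out so that a timestamp is
 * sampled. Simplified sketch, with p[] the next-token probabilities:
 *
 *   double sum_ts = 0.0, max_tx = 0.0;
 *   for (int i = vocab.token_beg; i < n_vocab; i++) { sum_ts += p[i]; }
 *   for (int i = 0; i < vocab.token_beg; i++) { max_tx = std::max(max_tx, (double) p[i]); }
 *   const bool force_ts = sum_ts > max_tx;   // then set p[0 .. token_beg) to -INFINITY
 */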
to_timestamp(int64_t t, bool comma = false) { int64_t msec = t * 10; - int64_t hr = msec / (1000 * 60 * 60); - msec = msec - hr * (1000 * 60 * 60); - int64_t min = msec / (1000 * 60); - msec = msec - min * (1000 * 60); - int64_t sec = msec / 1000; - msec = msec - sec * 1000; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; char buf[32]; snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); @@ -1973,20 +1990,20 @@ static std::string to_timestamp(int64_t t, bool comma = false) { static void dft(const std::vector & in, std::vector & out) { int N = in.size(); - out.resize(N*2); + out.resize(N * 2); for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - float angle = 2*M_PI*k*n/N; - re += in[n]*cos(angle); - im -= in[n]*sin(angle); + float angle = 2 * M_PI * k * n / N; + re += in[n] * cos(angle); + im -= in[n] * sin(angle); } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re; + out[k * 2 + 1] = im; } } @@ -1995,7 +2012,7 @@ static void dft(const std::vector & in, std::vector & out) { // input is real-valued // output is complex-valued static void fft(const std::vector & in, std::vector & out) { - out.resize(in.size()*2); + out.resize(in.size() * 2); int N = in.size(); @@ -2005,7 +2022,7 @@ static void fft(const std::vector & in, std::vector & out) { return; } - if (N%2 == 1) { + if (N % 2 == 1) { dft(in, out); return; } @@ -2013,8 +2030,8 @@ static void fft(const std::vector & in, std::vector & out) { std::vector even; std::vector odd; - even.reserve(N/2); - odd.reserve(N/2); + even.reserve(N / 2); + odd.reserve(N / 2); for (int i = 0; i < N; i++) { if (i % 2 == 0) { @@ -2030,20 +2047,20 @@ static void fft(const std::vector & in, std::vector & out) { fft(even, even_fft); fft(odd, odd_fft); - for (int k = 0; k < N/2; k++) { - float theta = 2*M_PI*k/N; + for (int k = 0; k < N / 2; k++) { + float theta = 2 * M_PI * k / N; float re = cos(theta); float im = -sin(theta); - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = even_fft[2 * k + 0] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; - out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + N / 2) + 0] = even_fft[2 * k + 0] - re * re_odd + im * im_odd; + out[2 * (k + N / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; } } @@ -2058,90 +2075,95 @@ static bool log_mel_spectrogram( const int n_threads, const whisper_filters & filters, const bool speed_up, - whisper_mel & mel) { + whisper_mel & mel +) { // Hanning window std::vector hann; hann.resize(fft_size); for (int i = 0; i < fft_size; i++) { - hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size))); + hann[i] = 0.5 * (1.0 - cos((2.0 * M_PI * i) / (fft_size))); } mel.n_mel = n_mel; - mel.n_len = (n_samples)/fft_step; - mel.data.resize(mel.n_mel*mel.n_len); + mel.n_len = (n_samples) / fft_step; + mel.data.resize(mel.n_mel * mel.n_len); - const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2); + const int n_fft = 1 + (speed_up ? 
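/*
 * The fft() above is a radix-2 Cooley-Tukey transform that falls back to
 * the O(N^2) dft() for odd lengths. With E/O the transforms of the even-
 * and odd-indexed samples and w = e^(-2*pi*i*k/N), the final loop applies
 * the butterfly
 *
 *   X[k]       = E[k] + w * O[k]
 *   X[k + N/2] = E[k] - w * O[k]
 *
 * with real/imaginary parts stored interleaved as out[2k] / out[2k + 1].
 */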
fft_size / 4 : fft_size / 2); - //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); - //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); + // printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); + // printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); std::vector workers(n_threads); for (int iw = 0; iw < n_threads; ++iw) { - workers[iw] = std::thread([&](int ith) { - std::vector fft_in; - fft_in.resize(fft_size); - for (int i = 0; i < fft_size; i++) { - fft_in[i] = 0.0; - } + workers[iw] = std::thread( + [&](int ith) { + std::vector fft_in; + fft_in.resize(fft_size); + for (int i = 0; i < fft_size; i++) { + fft_in[i] = 0.0; + } - std::vector fft_out; - fft_out.resize(2*fft_size); + std::vector fft_out; + fft_out.resize(2 * fft_size); - for (int i = ith; i < mel.n_len; i += n_threads) { - const int offset = i*fft_step; + for (int i = ith; i < mel.n_len; i += n_threads) { + const int offset = i * fft_step; - // apply Hanning window - for (int j = 0; j < fft_size; j++) { - if (offset + j < n_samples) { - fft_in[j] = hann[j]*samples[offset + j]; - } else { - fft_in[j] = 0.0; + // apply Hanning window + for (int j = 0; j < fft_size; j++) { + if (offset + j < n_samples) { + fft_in[j] = hann[j] * samples[offset + j]; + } else { + fft_in[j] = 0.0; + } } - } - // FFT -> mag^2 - fft(fft_in, fft_out); + // FFT -> mag^2 + fft(fft_in, fft_out); - for (int j = 0; j < fft_size; j++) { - fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]); - } - for (int j = 1; j < fft_size/2; j++) { - //if (i == 0) { - // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]); - //} - fft_out[j] += fft_out[fft_size - j]; - } - if (i == 0) { - //for (int j = 0; j < fft_size; j++) { - // printf("%d: %e\n", j, fft_out[j]); - //} - } + for (int j = 0; j < fft_size; j++) { + fft_out[j] = + (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + for (int j = 1; j < fft_size / 2; j++) { + // if (i == 0) { + // printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]); + // } + fft_out[j] += fft_out[fft_size - j]; + } + if (i == 0) { + // for (int j = 0; j < fft_size; j++) { + // printf("%d: %e\n", j, fft_out[j]); + // } + } - if (speed_up) { - // scale down in the frequency domain results in a speed up in the time domain - for (int j = 0; j < n_fft; j++) { - fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]); + if (speed_up) { + // scale down in the frequency domain results in a speed up in the time domain + for (int j = 0; j < n_fft; j++) { + fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]); + } } - } - // mel spectrogram - for (int j = 0; j < mel.n_mel; j++) { - double sum = 0.0; + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; - for (int k = 0; k < n_fft; k++) { - sum += fft_out[k]*filters.data[j*n_fft + k]; - } - if (sum < 1e-10) { - sum = 1e-10; - } + for (int k = 0; k < n_fft; k++) { + sum += fft_out[k] * filters.data[j * n_fft + k]; + } + if (sum < 1e-10) { + sum = 1e-10; + } - sum = log10(sum); + sum = log10(sum); - mel.data[j*mel.n_len + i] = sum; + mel.data[j * mel.n_len + i] = sum; + } } - } - }, iw); + }, + iw + ); } for (int iw = 0; iw < n_threads; ++iw) { @@ -2150,21 +2172,21 @@ static bool log_mel_spectrogram( // clamping and normalization double mmax = -1e20; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + for (int i = 0; i < mel.n_mel * mel.n_len; i++) { if (mel.data[i] > mmax) { mmax = 
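/*
 * Per-frame pipeline run by the worker threads above, assuming the Whisper
 * defaults (fft_size = 400, fft_step = 160, n_mel = 80, 16 kHz input):
 *
 *   frame  = hann .* samples[offset .. offset + fft_size)        // zero-padded at the end
 *   power  = |FFT(frame)|^2, upper half folded onto the lower    // -> n_fft bins
 *   mel[j] = log10(max(1e-10, sum_k power[k] * filters[j][k]))   // mel filter bank
 */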
mel.data[i]; } } - //printf("%s: max = %f\n", __func__, mmax); + // printf("%s: max = %f\n", __func__, mmax); mmax -= 8.0; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + for (int i = 0; i < mel.n_mel * mel.n_len; i++) { if (mel.data[i] < mmax) { mel.data[i] = mmax; } - mel.data[i] = (mel.data[i] + 4.0)/4.0; + mel.data[i] = (mel.data[i] + 4.0) / 4.0; } return true; @@ -2186,7 +2208,8 @@ static std::vector tokenize(const whisper_vocab & vocab, cons // first split the text into words { std::string str = text; - std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + std::string pat = + R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; std::regex re(pat); std::smatch m; @@ -2202,14 +2225,16 @@ static std::vector tokenize(const whisper_vocab & vocab, cons // find the longest tokens that form the words: std::vector tokens; for (const auto & word : words) { - if (word.empty()) continue; + if (word.empty()) { + continue; + } int i = 0; int n = word.size(); while (i < n) { int j = n; while (j > i) { - auto it = vocab.token_to_id.find(word.substr(i, j-i)); + auto it = vocab.token_to_id.find(word.substr(i, j - i)); if (it != vocab.token_to_id.end()) { tokens.push_back(it->second); i = j; @@ -2277,7 +2302,18 @@ void whisper_free(struct whisper_context * ctx) { int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { const int64_t t_start_us = ggml_time_us(); - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + if (!log_mel_spectrogram( + samples, + n_samples, + WHISPER_SAMPLE_RATE, + WHISPER_N_FFT, + WHISPER_HOP_LENGTH, + WHISPER_N_MEL, + n_threads, + ctx->model.filters, + false, + ctx->mel + )) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2288,10 +2324,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int } // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 -int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { +int whisper_pcm_to_mel_phase_vocoder( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads +) { const int64_t t_start_us = ggml_time_us(); - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { + if (!log_mel_spectrogram( + samples, + n_samples, + WHISPER_SAMPLE_RATE, + 2 * WHISPER_N_FFT, + 2 * WHISPER_HOP_LENGTH, + WHISPER_N_MEL, + n_threads, + ctx->model.filters, + true, + ctx->mel + )) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2301,11 +2353,7 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * return 0; } -int whisper_set_mel( - struct whisper_context * ctx, - const float * data, - int n_len, - int n_mel) { +int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel) { if (n_mel != WHISPER_N_MEL) { fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); return -1; @@ -2314,8 +2362,8 @@ int whisper_set_mel( ctx->mel.n_len = n_len; ctx->mel.n_mel = n_mel; - ctx->mel.data.resize(n_len*n_mel); - memcpy(ctx->mel.data.data(), 
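/*
 * The clamping and normalization above mirror openai/whisper: the log-mel
 * is kept within 8 units of its maximum and then rescaled,
 *
 *   m[i] = max(m[i], mmax - 8.0)
 *   m[i] = (m[i] + 4.0) / 4.0
 */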
data, n_len*n_mel*sizeof(float)); + ctx->mel.data.resize(n_len * n_mel); + memcpy(ctx->mel.data.data(), data, n_len * n_mel * sizeof(float)); return 0; } @@ -2333,7 +2381,13 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { return 0; } -int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { +int whisper_decode( + struct whisper_context * ctx, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads +) { const int64_t t_start_us = ggml_time_us(); if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) { @@ -2349,7 +2403,8 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); - const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); + const auto res = + whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2359,7 +2414,8 @@ struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { const int64_t t_start_sample_us = ggml_time_us(); - const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); + const auto res = + whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2416,12 +2472,8 @@ const char * whisper_lang_str(int id) { return nullptr; } -int whisper_lang_auto_detect( - struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs) { - const int seek = offset_ms/10; +int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs) { + const int seek = offset_ms / 10; if (seek < 0) { fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms); @@ -2429,7 +2481,13 @@ int whisper_lang_auto_detect( } if (seek >= ctx->mel.n_len) { - fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10); + fprintf( + stderr, + "%s: offset %dms is past the end of the audio (%dms)\n", + __func__, + offset_ms, + ctx->mel.n_len * 10 + ); return -2; } @@ -2478,7 +2536,8 @@ int whisper_lang_auto_detect( lang_probs[probs_id[i].second] = probs_id[i].first; } - //printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first); + // printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), + // probs_id[i].first); } } @@ -2553,12 +2612,24 @@ void whisper_print_timings(struct whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); fprintf(stderr, "\n"); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); - fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f); - fprintf(stderr, "%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_encode_us/1000.0f, ctx->t_encode_us/1000.0f/ctx->model.hparams.n_audio_layer); - fprintf(stderr, "%s: decode 
time = %8.2f ms / %.2f ms per layer\n", __func__, ctx->t_decode_us/1000.0f, ctx->t_decode_us/1000.0f/ctx->model.hparams.n_text_layer); - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us / 1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us / 1000.0f); + fprintf( + stderr, + "%s: encode time = %8.2f ms / %.2f ms per layer\n", + __func__, + ctx->t_encode_us / 1000.0f, + ctx->t_encode_us / 1000.0f / ctx->model.hparams.n_audio_layer + ); + fprintf( + stderr, + "%s: decode time = %8.2f ms / %.2f ms per layer\n", + __func__, + ctx->t_decode_us / 1000.0f, + ctx->t_decode_us / 1000.0f / ctx->model.hparams.n_text_layer + ); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us) / 1000.0f); } void whisper_reset_timings(struct whisper_context * ctx) { @@ -2570,17 +2641,17 @@ void whisper_reset_timings(struct whisper_context * ctx) { const char * whisper_print_system_info(void) { static std::string s; - s = ""; - s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; - s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; - s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; return s.c_str(); } @@ -2591,104 +2662,108 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str struct whisper_full_params result; switch (strategy) { - case WHISPER_SAMPLING_GREEDY: - { - result = { - /*.strategy =*/ WHISPER_SAMPLING_GREEDY, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - - /*.greedy =*/ { - /*.n_past =*/ 0, - }, - - /*.beam_search =*/ { - /*.n_past =*/ -1, - /*.beam_width =*/ -1, - /*.n_best =*/ -1, - }, - - /*.new_segment_callback =*/ nullptr, - 
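/*
 * Typical use of these defaults (sketch, error handling omitted; all field
 * names as declared in whisper.h):
 *
 *   struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
 *   wparams.n_threads      = 8;
 *   wparams.language       = "en";
 *   wparams.print_progress = false;
 *   whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
 */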
/*.new_segment_callback_user_data =*/ nullptr, - - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, - }; - } break; - case WHISPER_SAMPLING_BEAM_SEARCH: - { - result = { - /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - - /*.greedy =*/ { - /*.n_past =*/ -1, - }, - - /*.beam_search =*/ { - /*.n_past =*/ 0, - /*.beam_width =*/ 10, - /*.n_best =*/ 5, - }, - - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, - - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, - }; - } break; + case WHISPER_SAMPLING_GREEDY: { + result = { + /*.strategy =*/WHISPER_SAMPLING_GREEDY, + + /*.n_threads =*/std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/16384, + /*.offset_ms =*/0, + /*.duration_ms =*/0, + + /*.translate =*/false, + /*.no_context =*/false, + /*.single_segment =*/false, + /*.print_special =*/false, + /*.print_progress =*/true, + /*.print_realtime =*/false, + /*.print_timestamps =*/true, + + /*.token_timestamps =*/false, + /*.thold_pt =*/0.01f, + /*.thold_ptsum =*/0.01f, + /*.max_len =*/0, + /*.max_tokens =*/0, + + /*.speed_up =*/false, + /*.audio_ctx =*/0, + + /*.prompt_tokens =*/nullptr, + /*.prompt_n_tokens =*/0, + + /*.language =*/"en", + + /*.greedy =*/ + { + /*.n_past =*/0, + }, + + /*.beam_search =*/ + { + /*.n_past =*/-1, + /*.beam_width =*/-1, + /*.n_best =*/-1, + }, + + /*.new_segment_callback =*/ + nullptr, + /*.new_segment_callback_user_data =*/nullptr, + + /*.encoder_begin_callback =*/nullptr, + /*.encoder_begin_callback_user_data =*/nullptr, + }; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: { + result = { + /*.strategy =*/WHISPER_SAMPLING_BEAM_SEARCH, + + /*.n_threads =*/std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/16384, + /*.offset_ms =*/0, + /*.duration_ms =*/0, + + /*.translate =*/false, + /*.no_context =*/false, + /*.single_segment =*/false, + /*.print_special =*/false, + /*.print_progress =*/true, + /*.print_realtime =*/false, + /*.print_timestamps =*/true, + + /*.token_timestamps =*/false, + /*.thold_pt =*/0.01f, + /*.thold_ptsum =*/0.01f, + /*.max_len =*/0, + /*.max_tokens =*/0, + + /*.speed_up =*/false, + /*.audio_ctx =*/0, + + /*.prompt_tokens =*/nullptr, + /*.prompt_n_tokens =*/0, + + /*.language =*/"en", + + /*.greedy =*/ + { + /*.n_past =*/-1, + }, + + /*.beam_search =*/ + { + /*.n_past =*/0, + /*.beam_width =*/10, + /*.n_best =*/5, + }, + + /*.new_segment_callback =*/ + nullptr, + /*.new_segment_callback_user_data =*/nullptr, + + /*.encoder_begin_callback =*/nullptr, + /*.encoder_begin_callback_user_data =*/nullptr, + }; + } break; } return result; @@ -2697,10 +2772,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str // forward declarations static std::vector get_signal_energy(const float * signal, int n_samples, int 
n_samples_per_half_window); static void whisper_exp_compute_token_level_timestamps( - struct whisper_context * ctx, - int i_segment, - float thold_pt, - float thold_ptsum); + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum +); // wrap the last segment to max_len characters // returns the number of new segments @@ -2725,7 +2801,7 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { if (acc + cur > max_len && i > 0) { // split here ctx->result_all.back().text = std::move(text); - ctx->result_all.back().t1 = token.t0; + ctx->result_all.back().t1 = token.t0; ctx->result_all.back().tokens.resize(i); ctx->result_all.push_back({}); @@ -2733,16 +2809,14 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { ctx->result_all.back().t1 = segment.t1; // add tokens [i, end] to the new segment - ctx->result_all.back().tokens.insert( - ctx->result_all.back().tokens.end(), - segment.tokens.begin() + i, - segment.tokens.end()); + ctx->result_all.back() + .tokens.insert(ctx->result_all.back().tokens.end(), segment.tokens.begin() + i, segment.tokens.end()); - acc = 0; + acc = 0; text = ""; segment = ctx->result_all.back(); - i = -1; + i = -1; res++; } else { @@ -2757,10 +2831,11 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { } int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples) { + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples +) { // clear old results auto & result_all = ctx->result_all; @@ -2791,18 +2866,24 @@ int whisper_full( params.language = whisper_lang_str(lang_id); - fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + fprintf( + stderr, + "%s: auto-detected language: %s (p = %f)\n", + __func__, + params.language, + probs[whisper_lang_id(params.language)] + ); } if (params.token_timestamps) { - ctx->t_beg = 0; - ctx->t_last = 0; + ctx->t_beg = 0; + ctx->t_last = 0; ctx->tid_last = 0; - ctx->energy = get_signal_energy(samples, n_samples, 32); + ctx->energy = get_signal_energy(samples, n_samples, 32); } - const int seek_start = params.offset_ms/10; - const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10); + const int seek_start = params.offset_ms / 10; + const int seek_end = seek_start + (params.duration_ms == 0 ? 
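/*
 * Time units used throughout whisper_full(): one mel frame = one tick =
 * 10 ms (hop of 160 samples at 16 kHz), so offset_ms / 10 and
 * duration_ms / 10 convert milliseconds to ticks, and to_timestamp()
 * multiplies by 10 to get back to milliseconds.
 */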
whisper_n_len(ctx) : params.duration_ms / 10); // if length of spectrogram is less than 1s (100 samples), then return // basically don't process anything that is less than 1s @@ -2828,7 +2909,13 @@ int whisper_full( // overwrite audio_ctx, max allowed is hparams.n_audio_ctx if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { - fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + fprintf( + stderr, + "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", + __func__, + params.audio_ctx, + whisper_n_audio_ctx(ctx) + ); return -4; } ctx->exp_n_audio_ctx = params.audio_ctx; @@ -2857,7 +2944,7 @@ int whisper_full( // main loop int seek = seek_start; while (true) { - const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); + const int progress_cur = (100 * (seek - seek_start)) / (seek_end - seek_start); while (progress_cur >= progress_prev + progress_step) { progress_prev += progress_step; if (params.print_progress) { @@ -2894,7 +2981,8 @@ int whisper_full( // if we have already generated some text, use it as a prompt to condition the next generation if (!prompt_past.empty()) { - int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + int n_take = + std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx) / 2), int(prompt_past.size())); prompt = { whisper_token_prev(ctx) }; prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); @@ -2905,14 +2993,14 @@ int whisper_full( prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); - int seek_delta = 100*WHISPER_CHUNK_SIZE; + int seek_delta = 100 * WHISPER_CHUNK_SIZE; // print the prompt - //printf("\n\n"); - //for (int i = 0; i < prompt.size(); i++) { + // printf("\n\n"); + // for (int i = 0; i < prompt.size(); i++) { // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); //} - //printf("\n\n"); + // printf("\n\n"); // the accumulated transcription in the current interation int result_len = 0; @@ -2921,7 +3009,7 @@ int whisper_full( bool failed = false; bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? - for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + for (int i = 0, n_max = whisper_n_text_ctx(ctx) / 2 - 4; i < n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); return -5; @@ -2942,7 +3030,7 @@ int whisper_full( // timestamp token - update sliding window if (token.id > whisper_token_beg(ctx)) { - const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + const int seek_delta_new = 2 * (token.id - whisper_token_beg(ctx)); // do not allow to go back in time if (has_ts && seek_delta > seek_delta_new && result_len < i) { @@ -2951,7 +3039,7 @@ int whisper_full( seek_delta = seek_delta_new; result_len = i + 1; - has_ts = true; + has_ts = true; } // add it to the context @@ -2960,14 +3048,15 @@ int whisper_full( //{ // const auto tt = token.pt > 0.10 ? 
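/*
 * Context-conditioning sketch for the prompt built above: at most half of
 * the text context is spent on previously transcribed tokens, prefixed by
 * the special "previous" token:
 *
 *   n_take = min(params.n_max_text_ctx, n_text_ctx / 2, prompt_past.size())
 *   prompt = [ token_prev ] + last n_take of prompt_past + prompt_init
 *
 * which leaves the n_max = n_text_ctx / 2 - 4 positions of the sampling
 * loop for newly generated tokens.
 */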
ctx->vocab.id_to_token[token.tid] : "[?]"; - // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, + // ctx->vocab.id_to_token[token.id].c_str()); //} // end of segment if (token.id == whisper_token_eot(ctx) || // end of text token (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached - ) { + ) { if (result_len == 0) { if (seek + seek_delta + 100 >= seek_end) { result_len = i + 1; @@ -2979,7 +3068,7 @@ int whisper_full( if (params.single_segment) { result_len = i + 1; - seek_delta = 100*WHISPER_CHUNK_SIZE; + seek_delta = 100 * WHISPER_CHUNK_SIZE; } break; @@ -2987,7 +3076,7 @@ int whisper_full( // TESTS: if no tensors are loaded, it means we are running tests if (ctx->model.n_loaded == 0) { - seek_delta = 100*WHISPER_CHUNK_SIZE; + seek_delta = 100 * WHISPER_CHUNK_SIZE; break; } } @@ -2995,7 +3084,7 @@ int whisper_full( // sometimes, the decoding can get stuck in a repetition loop // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance // the sliding window by 1 second - if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2)) { failed = true; break; } @@ -3022,29 +3111,34 @@ int whisper_full( // store the text from this iteration if (!tokens_cur.empty()) { - int i0 = 0; - auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + int i0 = 0; + auto t0 = seek + 2 * (tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text; for (int i = 0; i < (int) tokens_cur.size(); i++) { - //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, - // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, - // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + // printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { } else { text += whisper_token_to_str(ctx, tokens_cur[i].id); } if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { - const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + const auto t1 = seek + 2 * (tokens_cur[i].tid - whisper_token_beg(ctx)); if (!text.empty()) { - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; + const auto tt0 = params.speed_up ? 2 * t0 : t0; + const auto tt1 = params.speed_up ? 
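/*
 * Sliding-window behaviour of the decode loop above: a sampled timestamp
 * token moves the window by seek_delta = 2 * (token.id - token_beg) ticks;
 * without a usable timestamp the window advances by a full chunk
 * (100 * WHISPER_CHUNK_SIZE ticks, i.e. 30 s), and a decode that looks
 * stuck in a repetition loop is flagged as failed and retried 1 s later.
 */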
2 * t1 : t1; if (params.print_realtime) { if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + printf( + "[%s --> %s] %s\n", + to_timestamp(tt0).c_str(), + to_timestamp(tt1).c_str(), + text.c_str() + ); } else { printf("%s", text.c_str()); fflush(stdout); @@ -3060,7 +3154,11 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + ctx, + result_all.size() - 1, + params.thold_pt, + params.thold_ptsum + ); if (params.max_len > 0) { n_new = whisper_wrap_segment(ctx, params.max_len); @@ -3083,8 +3181,8 @@ int whisper_full( if (!text.empty()) { const auto t1 = seek + seek_delta; - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; + const auto tt0 = params.speed_up ? 2 * t0 : t0; + const auto tt1 = params.speed_up ? 2 * t1 : t1; if (params.print_realtime) { if (params.print_timestamps) { @@ -3104,7 +3202,11 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + ctx, + result_all.size() - 1, + params.thold_pt, + params.thold_ptsum + ); if (params.max_len > 0) { n_new = whisper_wrap_segment(ctx, params.max_len); @@ -3123,11 +3225,12 @@ int whisper_full( } int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - int n_processors) { + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors +) { if (n_processors == 1) { return whisper_full(ctx, params, samples, n_samples); } @@ -3167,8 +3270,8 @@ int whisper_full_parallel( // key/value memory for the self-attention layer { - const int n_mem = n_text_layer*n_text_ctx; - const int n_elements = n_text_state*n_mem; + const int n_mem = n_text_layer * n_text_ctx; + const int n_elements = n_text_state * n_mem; model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); @@ -3178,8 +3281,8 @@ int whisper_full_parallel( { const int n_audio_ctx = hparams.n_audio_ctx; - const int n_mem = n_text_layer*n_audio_ctx; - const int n_elements = n_text_state*n_mem; + const int n_mem = n_text_layer * n_audio_ctx; + const int n_elements = n_text_state * n_mem; model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); @@ -3187,24 +3290,24 @@ int whisper_full_parallel( } } - const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; - const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + const int offset_samples = (WHISPER_SAMPLE_RATE * params.offset_ms) / 1000; + const int n_samples_per_processor = (n_samples - offset_samples) / n_processors; // the calling thread will process the first chunk // while the other threads will process the remaining chunks std::vector workers(n_processors - 1); for (int i = 0; i < n_processors - 1; ++i) { - const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int start_samples = offset_samples + (i + 1) * n_samples_per_processor; const int n_samples_cur = (i == n_processors - 2) ? 
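/*
 * Worked example for the segment timestamps above: timestamp tokens are
 * spaced 20 ms apart, so t = seek + 2 * (tid - token_beg) is measured in
 * 10 ms ticks. With seek = 0 and a token 150 steps past token_beg,
 * t = 300 ticks = 3.00 s; with speed_up (2x phase vocoder) the value is
 * doubled back to real time (tt = 2 * t).
 */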
n_samples - start_samples : n_samples_per_processor; auto params_cur = params; - params_cur.offset_ms = 0; + params_cur.offset_ms = 0; params_cur.print_progress = false; params_cur.print_realtime = false; - params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback = nullptr; params_cur.new_segment_callback_user_data = nullptr; workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur); @@ -3220,7 +3323,7 @@ int whisper_full_parallel( workers[i].join(); } - const int64_t offset_t = (int64_t) params.offset_ms/10.0; + const int64_t offset_t = (int64_t) params.offset_ms / 10.0; // combine results into ctx->result_all for (int i = 0; i < n_processors - 1; ++i) { @@ -3228,8 +3331,8 @@ int whisper_full_parallel( for (int j = 0; j < (int) results_i.size(); ++j) { // correct the segment timestamp taking into account the offset - results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; - results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t; + results_i[j].t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + results_i[j].t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; // make sure that segments are not overlapping if (!ctx->result_all.empty()) { @@ -3244,14 +3347,14 @@ int whisper_full_parallel( } } - ctx->t_mel_us += ctxs[i].t_mel_us; + ctx->t_mel_us += ctxs[i].t_mel_us; ctx->t_sample_us += ctxs[i].t_sample_us; ctx->t_encode_us += ctxs[i].t_encode_us; ctx->t_decode_us += ctxs[i].t_decode_us; } // average the timings - ctx->t_mel_us /= n_processors; + ctx->t_mel_us /= n_processors; ctx->t_sample_us /= n_processors; ctx->t_encode_us /= n_processors; ctx->t_decode_us /= n_processors; @@ -3260,7 +3363,13 @@ int whisper_full_parallel( fprintf(stderr, "\n"); fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); for (int i = 0; i < n_processors - 1; ++i) { - fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); + fprintf( + stderr, + "%s: split %d - %s\n", + __func__, + (i + 1), + to_timestamp(100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t).c_str() + ); } fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__); @@ -3319,11 +3428,11 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int // static int timestamp_to_sample(int64_t t, int n_samples) { - return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); + return std::max(0, std::min((int) n_samples - 1, (int) ((t * WHISPER_SAMPLE_RATE) / 100))); } static int64_t sample_to_timestamp(int i_sample) { - return (100*i_sample)/WHISPER_SAMPLE_RATE; + return (100 * i_sample) / WHISPER_SAMPLE_RATE; } // a cost-function / heuristic that is high for text that takes longer to pronounce @@ -3365,17 +3474,18 @@ static std::vector get_signal_energy(const float * signal, int n_samples, sum += fabs(signal[i + j]); } } - result[i] = sum/(2*hw + 1); + result[i] = sum / (2 * hw + 1); } return result; } static void whisper_exp_compute_token_level_timestamps( - struct whisper_context * ctx, - int i_segment, - float thold_pt, - float thold_ptsum) { + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum +) { auto & segment = ctx->result_all[i_segment]; auto & tokens 
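/*
 * Helpers above: timestamp_to_sample() / sample_to_timestamp() convert
 * between 10 ms ticks and sample indices via samples = t * SAMPLE_RATE / 100,
 * and get_signal_energy() is a moving average of |signal| over a window of
 * 2 * n_samples_per_half_window + 1 samples, used below as a crude
 * voice-activity signal.
 */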
= segment.tokens; @@ -3411,19 +3521,19 @@ static void whisper_exp_compute_token_level_timestamps( if (j == 0) { if (token.id == whisper_token_beg(ctx)) { - tokens[j ].t0 = t0; - tokens[j ].t1 = t0; + tokens[j].t0 = t0; + tokens[j].t1 = t0; tokens[j + 1].t0 = t0; t_beg = t0; t_last = t0; tid_last = whisper_token_beg(ctx); } else { - tokens[j ].t0 = t_last; + tokens[j].t0 = t_last; } } - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + const int64_t tt = t_beg + 2 * (token.tid - whisper_token_beg(ctx)); tokens[j].id = token.id; tokens[j].tid = token.tid; @@ -3438,7 +3548,7 @@ static void whisper_exp_compute_token_level_timestamps( tokens[j - 1].t1 = tt; } tokens[j].t0 = tt; - tid_last = token.tid; + tid_last = token.tid; } } @@ -3469,16 +3579,16 @@ static void whisper_exp_compute_token_level_timestamps( psum += tokens[j].vlen; } - //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + // printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); const double dt = tokens[p1].t1 - tokens[p0].t0; // split the time proportionally to the voice length for (int j = p0 + 1; j <= p1; j++) { - const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + const double ct = tokens[j - 1].t0 + dt * tokens[j - 1].vlen / psum; tokens[j - 1].t1 = ct; - tokens[j ].t0 = ct; + tokens[j].t0 = ct; } } @@ -3507,7 +3617,7 @@ static void whisper_exp_compute_token_level_timestamps( // VAD // expand or contract tokens based on voice activity { - const int hw = WHISPER_SAMPLE_RATE/8; + const int hw = WHISPER_SAMPLE_RATE / 8; for (int j = 0; j < n; j++) { if (tokens[j].id >= whisper_token_eot(ctx)) { @@ -3528,7 +3638,7 @@ static void whisper_exp_compute_token_level_timestamps( sum += ctx->energy[k]; } - const float thold = 0.5*sum/ns; + const float thold = 0.5 * sum / ns; { int k = s0; @@ -3546,7 +3656,7 @@ static void whisper_exp_compute_token_level_timestamps( while (ctx->energy[k] < thold && k < s1) { k++; } - s0 = k; + s0 = k; tokens[j].t0 = sample_to_timestamp(k); } } @@ -3567,7 +3677,7 @@ static void whisper_exp_compute_token_level_timestamps( while (ctx->energy[k] < thold && k > s0) { k--; } - s1 = k; + s1 = k; tokens[j].t1 = sample_to_timestamp(k); } } @@ -3589,11 +3699,12 @@ static void whisper_exp_compute_token_level_timestamps( //} // debug info - //for (int j = 0; j < n; ++j) { + // for (int j = 0; j < n; ++j) { // const auto & token = tokens[j]; // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, - // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, + // whisper_token_to_str(ctx, token.id)); // if (tokens[j].id >= whisper_token_eot(ctx)) { // continue; diff --git a/whisper.h b/whisper.h index e36b761..3328288 100644 --- a/whisper.h +++ b/whisper.h @@ -1,8 +1,8 @@ #ifndef WHISPER_H #define WHISPER_H -#include #include +#include #ifdef WHISPER_SHARED # ifdef _WIN32 @@ -12,7 +12,7 @@ # define WHISPER_API __declspec(dllimport) # endif # else -# define WHISPER_API __attribute__ ((visibility ("default"))) +# define WHISPER_API __attribute__((visibility("default"))) # endif #else # define WHISPER_API @@ -28,301 +28,295 @@ extern "C" { #endif - // - // C interface - // - // The following interface is thread-safe as long as the sample whisper_context is not used by multiple threads - // concurrently. 
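/*
 * (On the token-level timestamp routine above.) Within a token group
 * [p0, p1] the boundaries are interpolated in proportion to each token's
 * voice length, ct = t0(j-1) + dt * vlen(j-1) / psum, and then expanded or
 * contracted by the energy-based VAD pass, which thresholds at half the
 * mean energy around the token.
 */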
- // - // Basic usage: - // - // #include "whisper.h" - // - // ... - // - // struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin"); - // - // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { - // fprintf(stderr, "failed to process audio\n"); - // return 7; - // } - // - // const int n_segments = whisper_full_n_segments(ctx); - // for (int i = 0; i < n_segments; ++i) { - // const char * text = whisper_full_get_segment_text(ctx, i); - // printf("%s", text); - // } - // - // whisper_free(ctx); - // - // ... - // - // This is a demonstration of the most straightforward usage of the library. - // "pcmf32" contains the RAW audio data in 32-bit floating point format. - // - // The interface also allows for more fine-grained control over the computation, but it requires a deeper - // understanding of how the model works. - // - - struct whisper_context; - - typedef int whisper_token; - - typedef struct whisper_token_data { - whisper_token id; // token id - whisper_token tid; // forced timestamp token id - - float p; // probability of the token - float pt; // probability of the timestamp token - float ptsum; // sum of probabilities of all timestamp tokens - - // token-level timestamp data - // do not use if you haven't computed token-level timestamps - int64_t t0; // start time of the token - int64_t t1; // end time of the token - - float vlen; // voice length of the token - } whisper_token_data; - - // Allocates all memory needed for the model and loads the model from the given file. - // Returns NULL on failure. - WHISPER_API struct whisper_context * whisper_init(const char * path_model); - - // Frees all memory allocated by the model. - WHISPER_API void whisper_free(struct whisper_context * ctx); - - // Convert RAW PCM audio to log mel spectrogram. - // The resulting spectrogram is stored inside the provided whisper context. - // Returns 0 on success - WHISPER_API int whisper_pcm_to_mel( - struct whisper_context * ctx, - const float * samples, - int n_samples, - int n_threads); - - // This can be used to set a custom log mel spectrogram inside the provided whisper context. - // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. - // n_mel must be 80 - // Returns 0 on success - WHISPER_API int whisper_set_mel( - struct whisper_context * ctx, - const float * data, - int n_len, - int n_mel); - - // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. - // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. - // offset can be used to specify the offset of the first frame in the spectrogram. - // Returns 0 on success - WHISPER_API int whisper_encode( - struct whisper_context * ctx, - int offset, - int n_threads); - - // Run the Whisper decoder to obtain the logits and probabilities for the next token. - // Make sure to call whisper_encode() first. - // tokens + n_tokens is the provided context for the decoder. - // n_past is the number of tokens to use from previous decoder calls. - // Returns 0 on success - WHISPER_API int whisper_decode( - struct whisper_context * ctx, - const whisper_token * tokens, - int n_tokens, - int n_past, - int n_threads); - - // Token sampling methods. - // These are provided for convenience and can be used after each call to whisper_decode(). - // You can also implement your own sampling method using the whisper_get_probs() function. 
- // whisper_sample_best() returns the token with the highest probability - // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); - WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); - - // Convert the provided text into tokens. - // The tokens pointer must be large enough to hold the resulting tokens. - // Returns the number of tokens on success, no more than n_max_tokens - // Returns -1 on failure - // TODO: not sure if correct - WHISPER_API int whisper_tokenize( - struct whisper_context * ctx, - const char * text, - whisper_token * tokens, - int n_max_tokens); - - // Largest language id (i.e. number of available languages - 1) - WHISPER_API int whisper_lang_max_id(); - - // Return the id of the specified language, returns -1 if not found - // Examples: - // "de" -> 2 - // "german" -> 2 - WHISPER_API int whisper_lang_id(const char * lang); - - // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found - WHISPER_API const char * whisper_lang_str(int id); - - // Use mel data at offset_ms to try and auto-detect the spoken language - // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first - // Returns the top language id or negative on failure - // If not null, fills the lang_probs array with the probabilities of all languages - // The array must be whispe_lang_max_id() + 1 in size - // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 - WHISPER_API int whisper_lang_auto_detect( - struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs); - - WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length - WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); - WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); - - // The probabilities for the next token - WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); - - // Token Id -> String. 
Uses the vocabulary in the provided context - WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); - - // Special tokens - WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); - - // Task tokens - WHISPER_API whisper_token whisper_token_translate (void); - WHISPER_API whisper_token whisper_token_transcribe(void); - - // Performance information - WHISPER_API void whisper_print_timings(struct whisper_context * ctx); - WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); - - // Print system information - WHISPER_API const char * whisper_print_system_info(void); - - //////////////////////////////////////////////////////////////////////////// - - // Available sampling strategies - enum whisper_sampling_strategy { - WHISPER_SAMPLING_GREEDY, // Always select the most probable token - WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! - }; - - // Text segment callback - // Called on every newly generated text segment - // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); - - // Encoder begin callback - // If not NULL, called before the encoder starts - // If it returns false, the computation is aborted - typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); - - // Parameters for the whisper_full() function - // If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp: - // whisper_full_default_params() - struct whisper_full_params { - enum whisper_sampling_strategy strategy; - - int n_threads; - int n_max_text_ctx; - int offset_ms; // start offset in ms - int duration_ms; // audio duration to process in ms - - bool translate; - bool no_context; - bool single_segment; // force single segment output (useful for streaming) - bool print_special; - bool print_progress; - bool print_realtime; - bool print_timestamps; - - // [EXPERIMENTAL] token-level timestamps - bool token_timestamps; // enable token-level timestamps - float thold_pt; // timestamp token probability threshold (~0.01) - float thold_ptsum; // timestamp token sum probability threshold (~0.01) - int max_len; // max segment length in characters - int max_tokens; // max tokens per segment (0 = no limit) - - // [EXPERIMENTAL] speed-up techniques - bool speed_up; // speed-up the audio by 2x using Phase Vocoder - int audio_ctx; // overwrite the audio context size (0 = use default) - - // tokens to provide the whisper model as initial prompt - // these are prepended to any existing text context from a previous call - const whisper_token * prompt_tokens; - int prompt_n_tokens; - - // for auto-detection, set to nullptr, "" or "auto" - const char * language; - - struct { - int n_past; - } greedy; - - struct { - int n_past; - int beam_width; - int n_best; - } beam_search; - - whisper_new_segment_callback new_segment_callback; - void * new_segment_callback_user_data; - - 
whisper_encoder_begin_callback encoder_begin_callback; - void * encoder_begin_callback_user_data; - }; - - WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); - - // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - // Uses the specified decoding strategy to obtain the text. - WHISPER_API int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples); - - // Split the input audio in chunks and process each chunk separately using whisper_full() - // It seems this approach can offer some speedup in some cases. - // However, the transcription accuracy can be worse at the beginning and end of each chunk. - WHISPER_API int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - int n_processors); - - // Number of generated text segments. - // A segment can be a few words, a sentence, or even a paragraph. - WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); - - // Get the start and end time of the specified segment. - WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); - WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); - - // Get the text of the specified segment. - WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); - - // Get number of tokens in the specified segment. - WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); - - // Get the token text of the specified token in the specified segment. - WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); - - // Get token data for the specified token in the specified segment. - // This contains probabilities, timestamps, etc. - WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); - - // Get the probability of the specified token in the specified segment. - WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); +// +// C interface +// +// The following interface is thread-safe as long as the same whisper_context is not used by multiple threads +// concurrently. +// +// Basic usage: +// +// #include "whisper.h" +// +// ... +// +// struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin"); +// +// if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { +// fprintf(stderr, "failed to process audio\n"); +// return 7; +// } +// +// const int n_segments = whisper_full_n_segments(ctx); +// for (int i = 0; i < n_segments; ++i) { +// const char * text = whisper_full_get_segment_text(ctx, i); +// printf("%s", text); +// } +// +// whisper_free(ctx); +// +// ... +// +// This is a demonstration of the most straightforward usage of the library. +// "pcmf32" contains the RAW audio data in 32-bit floating point format. +// +// The interface also allows for more fine-grained control over the computation, but it requires a deeper +// understanding of how the model works.
+// + +struct whisper_context; + +typedef int whisper_token; + +typedef struct whisper_token_data { + whisper_token id; // token id + whisper_token tid; // forced timestamp token id + + float p; // probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + float vlen; // voice length of the token +} whisper_token_data; + +// Allocates all memory needed for the model and loads the model from the given file. +// Returns NULL on failure. +WHISPER_API struct whisper_context * whisper_init(const char * path_model); + +// Frees all memory allocated by the model. +WHISPER_API void whisper_free(struct whisper_context * ctx); + +// Convert RAW PCM audio to log mel spectrogram. +// The resulting spectrogram is stored inside the provided whisper context. +// Returns 0 on success +WHISPER_API int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads); + +// This can be used to set a custom log mel spectrogram inside the provided whisper context. +// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. +// n_mel must be 80 +// Returns 0 on success +WHISPER_API int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel); + +// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context. +// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. +// offset can be used to specify the offset of the first frame in the spectrogram. +// Returns 0 on success +WHISPER_API int whisper_encode(struct whisper_context * ctx, int offset, int n_threads); + +// Run the Whisper decoder to obtain the logits and probabilities for the next token. +// Make sure to call whisper_encode() first. +// tokens + n_tokens is the provided context for the decoder. +// n_past is the number of tokens to use from previous decoder calls. +// Returns 0 on success +WHISPER_API int whisper_decode( + struct whisper_context * ctx, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads +); + +// Token sampling methods. +// These are provided for convenience and can be used after each call to whisper_decode(). +// You can also implement your own sampling method using the whisper_get_probs() function. +// whisper_sample_best() returns the token with the highest probability +// whisper_sample_timestamp() returns the most probable timestamp token +WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); +WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); + +// Convert the provided text into tokens. +// The tokens pointer must be large enough to hold the resulting tokens. +// Returns the number of tokens on success, no more than n_max_tokens +// Returns -1 on failure +// TODO: not sure if correct +WHISPER_API int whisper_tokenize( + struct whisper_context * ctx, + const char * text, + whisper_token * tokens, + int n_max_tokens +); + +// Largest language id (i.e. 
number of available languages - 1) +WHISPER_API int whisper_lang_max_id(); + +// Return the id of the specified language, returns -1 if not found +// Examples: +// "de" -> 2 +// "german" -> 2 +WHISPER_API int whisper_lang_id(const char * lang); + +// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found +WHISPER_API const char * whisper_lang_str(int id); + +// Use mel data at offset_ms to try to auto-detect the spoken language +// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first +// Returns the top language id or negative on failure +// If not null, fills the lang_probs array with the probabilities of all languages +// The array must be whisper_lang_max_id() + 1 in size +// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 +WHISPER_API int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs +); + +WHISPER_API int whisper_n_len(struct whisper_context * ctx); // mel length +WHISPER_API int whisper_n_vocab(struct whisper_context * ctx); +WHISPER_API int whisper_n_text_ctx(struct whisper_context * ctx); +WHISPER_API int whisper_n_audio_ctx(struct whisper_context * ctx); +WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); + +// The probabilities for the next token +WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); + +// Token Id -> String. Uses the vocabulary in the provided context +WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); + +// Special tokens +WHISPER_API whisper_token whisper_token_eot(struct whisper_context * ctx); +WHISPER_API whisper_token whisper_token_sot(struct whisper_context * ctx); +WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); +WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); +WHISPER_API whisper_token whisper_token_not(struct whisper_context * ctx); +WHISPER_API whisper_token whisper_token_beg(struct whisper_context * ctx); +WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); + +// Task tokens +WHISPER_API whisper_token whisper_token_translate(void); +WHISPER_API whisper_token whisper_token_transcribe(void); + +// Performance information +WHISPER_API void whisper_print_timings(struct whisper_context * ctx); +WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); + +// Print system information +WHISPER_API const char * whisper_print_system_info(void); + +//////////////////////////////////////////////////////////////////////////// + +// Available sampling strategies +enum whisper_sampling_strategy { + WHISPER_SAMPLING_GREEDY, // Always select the most probable token + WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
+}; + +// Text segment callback +// Called on every newly generated text segment +// Use the whisper_full_...() functions to obtain the text segments +typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); + +// Encoder begin callback +// If not NULL, called before the encoder starts +// If it returns false, the computation is aborted +typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + +// Parameters for the whisper_full() function +// If you change the order or add new parameters, make sure to update the default values in whisper.cpp: +// whisper_full_default_params() +struct whisper_full_params { + enum whisper_sampling_strategy strategy; + + int n_threads; + int n_max_text_ctx; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool translate; + bool no_context; + bool single_segment; // force single segment output (useful for streaming) + bool print_special; + bool print_progress; + bool print_realtime; + bool print_timestamps; + + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + int max_tokens; // max tokens per segment (0 = no limit) + + // [EXPERIMENTAL] speed-up techniques + bool speed_up; // speed up the audio by 2x using Phase Vocoder + int audio_ctx; // overwrite the audio context size (0 = use default) + + // tokens to provide the whisper model as initial prompt + // these are prepended to any existing text context from a previous call + const whisper_token * prompt_tokens; + int prompt_n_tokens; + + // for auto-detection, set to nullptr, "" or "auto" + const char * language; + + struct { + int n_past; + } greedy; + + struct { + int n_past; + int beam_width; + int n_best; + } beam_search; + + whisper_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + whisper_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; +}; + +WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); + +// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text +// Uses the specified decoding strategy to obtain the text. +WHISPER_API int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples +); + +// Split the input audio into chunks and process each chunk separately using whisper_full() +// It seems this approach can offer some speedup in some cases. +// However, the transcription accuracy can be worse at the beginning and end of each chunk. +WHISPER_API int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors +); + +// Number of generated text segments. +// A segment can be a few words, a sentence, or even a paragraph. +WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + +// Get the start and end time of the specified segment. +WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); +WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); + +// Get the text of the specified segment.
+WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + +// Get number of tokens in the specified segment. +WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + +// Get the token text of the specified token in the specified segment. +WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); +WHISPER_API whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token); + +// Get token data for the specified token in the specified segment. +// This contains probabilities, timestamps, etc. +WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + +// Get the probability of the specified token in the specified segment. +WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); #ifdef __cplusplus }
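
For reference, the following is a minimal, self-contained sketch of the basic call sequence documented in the reformatted whisper.h above; it is illustrative only, not part of the patch. It assumes the model path arrives as argv[1] and that pcmf32 has already been filled with mono float PCM at WHISPER_SAMPLE_RATE (16 kHz); audio decoding and full error handling are out of scope. Judging by the 100 * samples / WHISPER_SAMPLE_RATE arithmetic in whisper_full_parallel above, the segment timestamps t0/t1 are in units of 10 ms.

#include "whisper.h"

#include <cstdint>
#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <path to ggml model>\n", argv[0]);
        return 1;
    }

    // assumed to already contain mono float PCM at 16 kHz; decoding audio is out of scope here
    std::vector<float> pcmf32;

    struct whisper_context * ctx = whisper_init(argv[1]);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model '%s'\n", argv[1]);
        return 2;
    }

    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.print_progress   = false;
    wparams.print_realtime   = false;
    wparams.token_timestamps = true; // [EXPERIMENTAL] also compute per-token t0/t1

    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
        fprintf(stderr, "failed to process audio\n");
        whisper_free(ctx);
        return 3;
    }

    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = 0; i < n_segments; ++i) {
        // t0/t1 are in 10 ms units
        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
        printf("[%6lld -> %6lld] %s\n", (long long) t0, (long long) t1, whisper_full_get_segment_text(ctx, i));
    }

    whisper_print_timings(ctx);
    whisper_free(ctx);
    return 0;
}

The same wparams can instead be passed to whisper_full_parallel(ctx, wparams, pcmf32.data(), (int) pcmf32.size(), n_processors) to split the audio across contexts, with the accuracy caveats near chunk boundaries noted in the header.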