From 65f2c1bedc35439714536b0be7723fac8673d67a Mon Sep 17 00:00:00 2001
From: Todd Fisher
Date: Mon, 13 Feb 2023 11:18:24 -0500
Subject: [PATCH] adding ruby bindings

---
 bindings/ruby/ext/extconf.rb        |   20 +
 bindings/ruby/ext/ggml.c            | 8616 +++++++++++++++++++++++++++
 bindings/ruby/ext/ggml.h            |  748 +++
 bindings/ruby/ext/ruby_whisper.c    |   92 +
 bindings/ruby/ext/ruby_whisper.h    |   14 +
 bindings/ruby/ext/whisper.cpp       | 4814 +++++++++++++++
 bindings/ruby/ext/whisper.h         |  379 ++
 bindings/ruby/tests/test_whisper.rb |   23 +
 8 files changed, 14706 insertions(+)
 create mode 100644 bindings/ruby/ext/extconf.rb
 create mode 100644 bindings/ruby/ext/ggml.c
 create mode 100644 bindings/ruby/ext/ggml.h
 create mode 100644 bindings/ruby/ext/ruby_whisper.c
 create mode 100644 bindings/ruby/ext/ruby_whisper.h
 create mode 100644 bindings/ruby/ext/whisper.cpp
 create mode 100644 bindings/ruby/ext/whisper.h
 create mode 100644 bindings/ruby/tests/test_whisper.rb

diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb
new file mode 100644
index 0000000..0421cb2
--- /dev/null
+++ b/bindings/ruby/ext/extconf.rb
@@ -0,0 +1,20 @@
+require 'mkmf'
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
+system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
+
+
+# need to use c++ compiler flags
+$CXXFLAGS << ' -std=c++11'
+# Set to true when building binary gems
+if enable_config('static-stdlib', false)
+  $LDFLAGS << ' -static-libgcc -static-libstdc++'
+end
+
+if enable_config('march-tune-native', false)
+  $CFLAGS << ' -march=native -mtune=native'
+  $CXXFLAGS << ' -march=native -mtune=native'
+end
+
+create_makefile('whisper')
diff --git a/bindings/ruby/ext/ggml.c b/bindings/ruby/ext/ggml.c
new file mode 100644
index 0000000..d67612c
--- /dev/null
+++ b/bindings/ruby/ext/ggml.c
@@ -0,0 +1,8616 @@
+#include "ggml.h"
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__)
+#include <alloca.h>
+#endif
+
+#include <assert.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <stdio.h>
+
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+
+#if defined _MSC_VER || defined(__MINGW32__)
+
+#if !defined(__MINGW32__)
+#include <Windows.h>
+#else
+// ref: https://github.com/ggerganov/whisper.cpp/issues/168
+#include <windows.h>
+#include <errno.h>
+#endif
+
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+
+static void atomic_store(atomic_int* ptr, LONG val) {
+    InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int* ptr) {
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+    return atomic_fetch_add(ptr, -(dec));
+}
+
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
+    if (handle == NULL)
+    {
+        return EAGAIN;
+    }
+
+    *out = handle;
+    return 0;
+}
+
+static int pthread_join(pthread_t thread, void* unused) {
+    return (int) WaitForSingleObject(thread, INFINITE);
+}
+
+static int sched_yield (void) {
+    Sleep (0);
+    return 0;
+}
+#else
+#include <pthread.h>
+#include <stdatomic.h>
+
+typedef void* thread_ret_t;
+#endif
+
+#ifdef __HAIKU__
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#endif
+
+/*#define GGML_PERF*/
+#define GGML_DEBUG 0
+#define GGML_GELU_FP16
+
+#define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL  2
+
+#ifdef GGML_USE_ACCELERATE
+// uncomment to use vDSP for soft max computation
+// note: not sure if it is actually faster
+//#define GGML_SOFT_MAX_ACCELERATE
+#endif
+
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
+
+#define UNUSED(x) (void)(x)
+#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
+
+#define GGML_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
+#ifdef GGML_USE_ACCELERATE
+#include <Accelerate/Accelerate.h>
+#elif GGML_USE_OPENBLAS
+#include <cblas.h>
+#endif
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+
+#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP32_TO_FP16(x) (x)
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#include <immintrin.h>
+#endif
+#endif
+
+#ifdef __F16C__
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ?
fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static ggml_fp16_t table_gelu_f16[1 << 16]; + +// precomputed exp table for f16 (128 KB) +static ggml_fp16_t table_exp_f16[1 << 16]; + +// precomputed f32 table for f16 (256 KB) +static float table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, +// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. 
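+// note: a minimal sketch of how the lookup below works - it mirrors the
+// initialization loop in ggml_init() further down, shown here only for
+// illustration:
+//
+//     ggml_fp16_t ii;
+//     for (int i = 0; i < (1 << 16); ++i) {
+//         uint16_t ui = i;
+//         memcpy(&ii, &ui, sizeof(ii));
+//         table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
+//     }
+//
+// once the table is filled, GGML_FP16_TO_FP32 becomes a single indexed load
+// instead of the bit manipulation in ggml_compute_fp16_to_fp32()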
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) + +inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return table_f32_f16[s]; +} + +#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + +#endif + +// note: do not use these inside ggml.c +// these are meant to be used via the ggml.h API +float ggml_fp16_to_fp32(ggml_fp16_t x) { + return GGML_FP16_TO_FP32(x); +} + +ggml_fp16_t ggml_fp32_to_fp16(float x) { + return GGML_FP32_TO_FP16(x); +} + +// +// timing +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +static int64_t timer_freq; +void ggml_time_init(void) { + LARGE_INTEGER frequency; + QueryPerformanceFrequency(&frequency); + timer_freq = frequency.QuadPart; +} +int64_t ggml_time_ms(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return (t.QuadPart * 1000) / timer_freq; +} +int64_t ggml_time_us(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return (t.QuadPart * 1000000) / timer_freq; +} +#else +void ggml_time_init(void) {} +int64_t ggml_time_ms(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; +} + +int64_t ggml_time_us(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; +} +#endif + +int64_t ggml_cycles(void) { + return clock(); +} + +int64_t ggml_cycles_per_ms(void) { + return CLOCKS_PER_SEC/1000; +} + +#ifdef GGML_PERF +#define ggml_perf_time_ms() ggml_time_ms() +#define ggml_perf_time_us() ggml_time_us() +#define ggml_perf_cycles() ggml_cycles() +#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() +#else +#define ggml_perf_time_ms() 0 +#define ggml_perf_time_us() 0 +#define ggml_perf_cycles() 0 +#define ggml_perf_cycles_per_ms() 0 +#endif + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#if defined(__POWER9_VECTOR__) +#define CACHE_LINE_SIZE 128 +#else +#define CACHE_LINE_SIZE 64 +#endif +#endif + +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + +// +// simd mappings +// + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 NEON + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#if defined(__ARM_FEATURE_QRDMX) + #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#else + #define GGML_F32x4_REDUCE_ONE(x) \ + (vgetq_lane_f32(x, 0) + \ + vgetq_lane_f32(x, 1) + \ + vgetq_lane_f32(x, 2) + \ + vgetq_lane_f32(x, 3)) +#endif +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = 
vaddq_f32(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ + } \ + res = GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + { \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = vaddvq_f32(vaddq_f32(t0, t1)); \ + } + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX__) + +#define GGML_SIMD + +// F32 AVX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define 
GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead +// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32 + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vector float +#define GGML_F32x4_ZERO 0.0f +#define GGML_F32x4_SET1 vec_splats +#define GGML_F32x4_LOAD(p) vec_xl(0, p) +#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = vec_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = vec_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = vec_add(x[8*i], x[8*i+4]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 POWER9 +#define GGML_F16_STEP GGML_F32_STEP +#define GGML_F16_EPR GGML_F32_EPR +#define GGML_F16_VEC GGML_F32x4 +#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 +#define 
GGML_F16_VEC_FMA GGML_F32x4_FMA +#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// Use vec_xl, not vec_ld, in case the load address is not aligned. +#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ + vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ + vec_extract_fp32_from_shortl(vec_xl(0, p)) +#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ + r[i - GGML_ENDIAN_BYTE(0)]), \ + 0, p - GGML_F16_EPR) + +#elif defined(__wasm_simd128__) + +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ + x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ + x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ + x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) +#define 
GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE + +#elif defined(__SSE3__) + +#define GGML_SIMD + +// F32 SSE + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __m128 +#define GGML_F32x4_ZERO _mm_setzero_ps() +#define GGML_F32x4_SET1(x) _mm_set1_ps(x) +#define GGML_F32x4_LOAD _mm_loadu_ps +#define GGML_F32x4_STORE _mm_storeu_ps +#if defined(__FMA__) + // TODO: Does this work? + #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) +#else + #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) +#endif +#define GGML_F32x4_ADD _mm_add_ps +#define GGML_F32x4_MUL _mm_mul_ps +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ + x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ + x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \ + } \ + for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ + x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \ + } \ + const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 SSE + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 4 + +static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(x[0]); + tmp[1] = GGML_FP16_TO_FP32(x[1]); + tmp[2] = GGML_FP16_TO_FP32(x[2]); + tmp[3] = GGML_FP16_TO_FP32(x[3]); + + return _mm_loadu_ps(tmp); +} + +static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { + float arr[4]; + + _mm_storeu_ps(arr, y); + + x[0] = GGML_FP32_TO_FP16(arr[0]); + x[1] = GGML_FP32_TO_FP16(arr[1]); + x[2] = GGML_FP32_TO_FP16(arr[2]); + x[3] = GGML_FP32_TO_FP16(arr[3]); +} + +#define GGML_F32Cx4 __m128 +#define GGML_F32Cx4_ZERO _mm_setzero_ps() +#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) +#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) +#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) +#define GGML_F32Cx4_FMA GGML_F32x4_FMA +#define GGML_F32Cx4_ADD _mm_add_ps +#define GGML_F32Cx4_MUL _mm_mul_ps +#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + +#define GGML_F16_VEC GGML_F32Cx4 +#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE + +#endif + +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif + +// +// fundamental operations +// + +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i32(const int n, int32_t * x, 
const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { + ggml_float sumf = 0.0; + +#ifdef GGML_SIMD + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + sumf += x[i]*y[i]; + } +#endif + + *s = sumf; +} + +inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { + ggml_float sumf = 0.0; + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + } +#else + for (int i = 0; i < n; ++i) { + sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + } +#endif + + *s = sumf; +} + +// compute GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + + ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + x[i] = 
(ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + } + } +#endif + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + GGML_ASSERT(false); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + +//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } 
+#endif +} + +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); } +inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); } +inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } + +static const ggml_float GELU_COEF_A = 0.044715; +static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; + +inline static float ggml_gelu_f32(float x) { + return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); +} + +inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = table_gelu_f16[i16[i]]; + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]); + } +} +#endif + +inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += x[i]; + } + *s = sum; +#else + vDSP_sve(x, 1, s, n); +#endif +} + +inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + ggml_float max = -INFINITY; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + *s = max; +#else + vDSP_maxv(x, 1, s, n); +#endif +} + +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) 
printf(__VA_ARGS__) + +// +// data types +// + +static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { + sizeof(int8_t ), + sizeof(int16_t), + sizeof(int32_t), + sizeof(ggml_fp16_t), + sizeof(float ), +}; + +static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { + "NONE", + + "DUP", + "ADD", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "SUM", + "MEAN", + "REPEAT", + "ABS", + "SGN", + "NEG", + "STEP", + "RELU", + "GELU", + "NORM", + + "MUL_MAT", + + "SCALE", + "CPY", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "DIAG_MASK_INF", + "SOFT_MAX", + "ROPE", + "CONV_1D_1S", + "CONV_1D_2S", + + "FLASH_ATTN", + "FLASH_FF", +}; + +static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "none", + + "x", + "x+y", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "Σx", + "Σx/n", + "repeat(x)", + "abs(x)", + "sgn(x)", + "-x", + "step(x)", + "relu(x)", + "gelu(x)", + "norm(x)", + + "X*Y", + + "x*v", + "x-\\>y", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "diag_mask_inf(x)", + "soft_max(x)", + "rope(x)", + "conv_1d_1s(x)", + "conv_1d_2s(x)", + + "flash_attn(x)", + "flash_ff(x)", +}; + +// +// ggml object +// + +struct ggml_object { + size_t offs; + size_t size; + + struct ggml_object * next; + + char padding[8]; +}; + +static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); + +static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); +static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); + +// +// ggml context +// + +struct ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + + int n_objects; + + struct ggml_object * objects_begin; + struct ggml_object * objects_end; + + struct ggml_scratch scratch; + struct ggml_scratch scratch_save; +}; + +struct ggml_context_container { + bool used; + + struct ggml_context context; +}; + +// +// compute types +// + +enum ggml_task_type { + GGML_TASK_INIT = 0, + GGML_TASK_COMPUTE, + GGML_TASK_FINALIZE, +}; + +struct ggml_compute_params { + enum ggml_task_type type; + + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; +}; + +// +// ggml state +// + +struct ggml_state { + struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; +}; + +// global state +static struct ggml_state g_state; +static atomic_int g_state_barrier = 0; + +// barrier via spin lock +inline static void ggml_critical_section_start(void) { + int processing = atomic_fetch_add(&g_state_barrier, 1); + + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); // TODO: reconsider this + processing = atomic_fetch_add(&g_state_barrier, 1); + } +} + +// TODO: make this somehow automatically executed +// some sort of "sentry" mechanism +inline static void ggml_critical_section_end(void) { + atomic_fetch_sub(&g_state_barrier, 1); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_print_object(const struct ggml_object * obj) { + GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n", + obj->offs, obj->size, (const void *) obj->next); +} + +void ggml_print_objects(const struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + + while (obj != NULL) { + ggml_print_object(obj); + obj = obj->next; + } + + GGML_PRINT("%s: --- end ---\n", 
__func__); +} + +int ggml_nelements(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +int ggml_nrows(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +size_t ggml_nbytes(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type]; +} + +size_t ggml_type_size(enum ggml_type type) { + return GGML_TYPE_SIZE[type]; +} + +size_t ggml_element_size(const struct ggml_tensor * tensor) { + return GGML_TYPE_SIZE[tensor->type]; +} + +static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_is_vector(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[1] == tensor->nb[0]*tensor->ne[0] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static inline bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0] ) && + (t0->ne[1] == t1->ne[1] ) && + (t0->ne[2] == t1->ne[2] ) && + (t0->ne[3] == t1->ne[3] ); +} + +// check if t1 can be represented as a repeatition of t0 +static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline int ggml_up32(int n) { + return (n + 31) & ~31; +} + +static inline int ggml_up64(int n) { + return (n + 63) & ~63; +} + +static inline int ggml_up(int n, int m) { + // assert m is a power of 2 + GGML_ASSERT((m & (m - 1)) == 0); + return (n + m - 1) & ~(m - 1); +} + +// assert that 
pointer is aligned to GGML_MEM_ALIGN +#define ggml_assert_aligned(ptr) \ + assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_context * ggml_init(struct ggml_init_params params) { + // make this function thread safe + ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize GELU, EXP and F32 tables + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + ggml_fp16_t ii; + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); + table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + // initialize g_state + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + g_state = (struct ggml_state) { + /*.contexts =*/ { { 0 } }, + }; + + for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { + g_state.contexts[i].used = false; + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + is_first_call = false; + } + + // find non-used context in g_state + struct ggml_context * ctx = NULL; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (!g_state.contexts[i].used) { + g_state.contexts[i].used = true; + ctx = &g_state.contexts[i].context; + + GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); + break; + } + } + + if (ctx == NULL) { + GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + ggml_critical_section_end(); + + return NULL; + } + + *ctx = (struct ggml_context) { + /*.mem_size =*/ params.mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, + }; + + ggml_assert_aligned(ctx->mem_buffer); + + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + ggml_critical_section_end(); + + return ctx; +} + +void ggml_free(struct ggml_context * ctx) { + // make this function thread safe + ggml_critical_section_start(); + + bool found = false; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (&g_state.contexts[i].context == ctx) { + g_state.contexts[i].used = false; + + GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n", + __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size); + + if (ctx->mem_buffer_owned) { + free(ctx->mem_buffer); + } + + found = true; + break; + } + } + + if (!found) { + GGML_PRINT_DEBUG("%s: context not found\n", __func__); + } + + ggml_critical_section_end(); +} + +size_t ggml_used_mem(const struct ggml_context * ctx) { + return ctx->objects_end->offs + ctx->objects_end->size; +} + +size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { + const size_t result = ctx->scratch.data ? 
ctx->scratch.offs : 0; + + ctx->scratch = scratch; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_tensor * ggml_new_tensor_impl( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int* ne, + void* data) { + // always insert objects at the end of the context's memory pool + struct ggml_object * obj_cur = ctx->objects_end; + + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; + + size_t size_needed = 0; + + if (data == NULL) { + size_needed += GGML_TYPE_SIZE[type]; + for (int i = 0; i < n_dims; i++) { + size_needed *= ne[i]; + } + // align to GGML_MEM_ALIGN + size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN; + } + + char * const mem_buffer = ctx->mem_buffer; + struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + + if (ctx->scratch.data == NULL || data != NULL) { + size_needed += sizeof(struct ggml_tensor); + + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + }; + } else { + if (ctx->scratch.offs + size_needed > ctx->scratch.size) { + GGML_PRINT("%s: not enough space in the scratch memory\n", __func__); + assert(false); + return NULL; + } + + if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size); + assert(false); + return NULL; + } + + data = (char * const) ctx->scratch.data + ctx->scratch.offs; + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = sizeof(struct ggml_tensor), + .next = NULL, + }; + + //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed); + + ctx->scratch.offs += size_needed; + } + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); + + struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs); + + ggml_assert_aligned(result); + + *result = (struct ggml_tensor) { + /*.type =*/ type, + /*.n_dims =*/ n_dims, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ GGML_OP_NONE, + /*.is_param =*/ false, + /*.grad =*/ NULL, + /*.src0 =*/ NULL, + /*.src1 =*/ NULL, + /*.opt =*/ { NULL }, + /*.n_tasks =*/ 0, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + /*.data =*/ data == NULL ? 
(void *)(result + 1) : data, + /*.pad =*/ { 0 }, + }; + + ggml_assert_aligned(result->data); + + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = GGML_TYPE_SIZE[type]; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int * ne) { + return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL); +} + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0) { + return ggml_new_tensor(ctx, type, 1, &ne0); +} + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1) { + const int ne[2] = { ne0, ne1 }; + return ggml_new_tensor(ctx, type, 2, ne); +} + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2) { + const int ne[3] = { ne0, ne1, ne2 }; + return ggml_new_tensor(ctx, type, 3, ne); +} + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2, + int ne3) { + const int ne[4] = { ne0, ne1, ne2, ne3 }; + return ggml_new_tensor(ctx, type, 4, ne); +} + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + + ctx->scratch = ctx->scratch_save; + + ggml_set_i32(result, value); + + return result; +} + +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + + ctx->scratch = ctx->scratch_save; + + ggml_set_f32(result, value); + + return result; +} + +struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { + return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); +} + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { + memset(tensor->data, 0, ggml_nbytes(tensor)); + return tensor; +} + +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } + + return tensor; +} + +struct ggml_tensor * ggml_set_f32(struct ggml_tensor 
* tensor, float value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } + + return tensor; +} + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float 
*)(tensor->data))[i]; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +void * ggml_get_data(const struct ggml_tensor * tensor) { + return tensor->data; +} + +float * ggml_get_data_f32(const struct ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + return (float *)(tensor->data); +} + +struct ggml_tensor * ggml_view_tensor( + struct ggml_context * ctx, + const struct ggml_tensor * src) { + return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data); +} + +//////////////////////////////////////////////////////////////////////////////// + +// ggml_dup + +struct ggml_tensor * ggml_dup_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DUP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, true); +} + +// ggml_add + +struct ggml_tensor * ggml_add_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, true); +} + +// ggml_sub + +struct ggml_tensor * ggml_sub_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SUB; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, true); +} + +// ggml_mul + +struct ggml_tensor * ggml_mul_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + assert(is_node == false); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_MUL; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, true); +} + +// ggml_div + +struct ggml_tensor * ggml_div_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + assert(is_node == false); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DIV; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, true); +} + +// ggml_sqr + +struct ggml_tensor * ggml_sqr_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQR; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, true); +} + +// ggml_sqrt + +struct ggml_tensor * ggml_sqrt_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQRT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, true); +} + +// ggml_sum + +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_SUM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_mean + +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement + is_node = true; + } + + int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne); + + result->op = GGML_OP_MEAN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_repeat + +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_can_repeat(a, b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = GGML_OP_REPEAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_abs + +struct ggml_tensor * ggml_abs_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ABS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_abs_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_abs_impl(ctx, a, true); +} + + +// ggml_sgn + +struct ggml_tensor * ggml_sgn_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SGN; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sgn_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sgn_impl(ctx, a, true); +} + +// ggml_neg + +struct ggml_tensor * ggml_neg_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_NEG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_neg_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_neg_impl(ctx, a, true); +} + +// ggml_step + +struct ggml_tensor * ggml_step_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_STEP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_step_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_step_impl(ctx, a, true); +} + +// ggml_relu + +struct ggml_tensor * ggml_relu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_RELU; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_relu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_relu_impl(ctx, a, true); +} + +// ggml_gelu + +struct ggml_tensor * ggml_gelu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_GELU; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_impl(ctx, a, true); +} + +// ggml_norm + +struct ggml_tensor * ggml_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_norm_impl(ctx, a, true); +} + +// ggml_mul_mat + +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_can_mul_mat(a, b)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_MUL_MAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_scale + +struct ggml_tensor * ggml_scale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_is_scalar(b)); + assert(ggml_is_padded_1d(a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SCALE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, true); +} + +// ggml_cpy + +struct ggml_tensor * ggml_cpy_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + assert(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + assert(false); // TODO: implement backward + is_node = true; + } + + // make a view of the destination + struct ggml_tensor * result = ggml_view_tensor(ctx, b); + + result->op = GGML_OP_CPY; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, true); +} + +// ggml_reshape + +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_contiguous(a)); + assert(ggml_is_contiguous(b)); + assert(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1) { + assert(ggml_is_contiguous(a)); + assert(ggml_nelements(a) == ne0*ne1); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[2] = { ne0, ne1 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2) { + assert(ggml_is_contiguous(a)); + assert(ggml_nelements(a) == ne0*ne1*ne2); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[3] = { ne0, ne1, ne2 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_view_1d + +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + size_t offset) { + if (a->grad) { + assert(false); // gradient propagation is not supported + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); + + result->op = GGML_OP_VIEW; + result->grad = NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the offset here? + + return result; +} + +// ggml_view_2d + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + size_t nb1, + size_t offset) { + if (a->grad) { + assert(false); // gradient propagation is not supported + } + + const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 }; + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + result->op = GGML_OP_VIEW; + result->grad = NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the offset here? 
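+ // note: the view aliases a->data at the given byte offset - no copy is made,
+ // so writes through the view are visible in the parent tensor.
+ // usage sketch (hypothetical t, r; ggml_view_1d is defined above): a 1-d view
+ // of row r of a contiguous 2-d tensor t:
+ //   struct ggml_tensor * row = ggml_view_1d(ctx, t, t->ne[0], r*t->nb[1]);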
+ + return result; +} + +// ggml_permute + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + assert(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + assert(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + assert(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + assert(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + assert(axis0 != axis1); + assert(axis0 != axis2); + assert(axis0 != axis3); + assert(axis1 != axis2); + assert(axis1 != axis3); + assert(axis2 != axis3); + + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + int ne[GGML_MAX_DIMS]; + int nb[GGML_MAX_DIMS]; + + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; + + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; + + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; + + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = GGML_OP_PERMUTE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store the permutation here? + + return result; +} + +// ggml_transpose + +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = GGML_OP_TRANSPOSE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_get_rows + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); + + result->op = GGML_OP_GET_ROWS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_diag_mask_inf + +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + struct ggml_tensor * b = ggml_new_i32(ctx, n_past); + + result->op = GGML_OP_DIAG_MASK_INF; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_soft_max + +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +// ggml_rope + +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode) { + assert(n_past >= 0); + bool is_node = false; + + if (a->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + ((int32_t *) b->data)[0] = n_past; + ((int32_t *) b->data)[1] = n_dims; + ((int32_t *) b->data)[2] = mode; + + result->op = GGML_OP_ROPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_conv_1d_1s + +struct ggml_tensor * ggml_conv_1d_1s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(b)); + assert(a->ne[1] == b->ne[1]); + assert(a->ne[3] == 1); + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[4] = { b->ne[0], a->ne[2], 1, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + result->op = GGML_OP_CONV_1D_1S; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_conv_1d_2s + +struct ggml_tensor * ggml_conv_1d_2s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + assert(ggml_is_matrix(b)); + assert(a->ne[1] == b->ne[1]); + assert(a->ne[3] == 1); + bool is_node = false; + + if (a->grad || b->grad) { + assert(false); // TODO: implement backward + is_node = true; + } + + const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + result->op = GGML_OP_CONV_1D_2S; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_flash_attn + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked) { + assert(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); + + result->op = GGML_OP_FLASH_ATTN; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0); + + return result; +} + +// ggml_flash_ff + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1) { + assert(ggml_can_mul_mat(b0, a)); + // TODO: more checks + + bool is_node = false; + + if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); + + result->op = GGML_OP_FLASH_FF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b0; + result->opt[0] = b1; + result->opt[1] = c0; + result->opt[2] = c1; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor) { + tensor->is_param = true; + + assert(tensor->grad == NULL); + tensor->grad = ggml_dup_tensor(ctx, tensor); +} + +// ggml_compute_forward_dup + +static void ggml_compute_forward_dup_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_contiguous(dst)); + assert(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + if (ggml_is_contiguous(src0) && src0->type == dst->type) { + memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); + return; + } + + if (src0->nb[0] == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + int id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } 
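+ // generic strided path: each fp16 element is read through its byte
+ // strides (nb00..nb03) and widened to fp32; dst is contiguous, so id
+ // simply increments across the flattened output.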
+ } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } +} + +static void ggml_compute_forward_dup_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + if (ggml_is_contiguous(src0) && src0->type == dst->type) { + memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); + return; + } + + if (src0->nb[0] == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + int id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + int id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + int id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } +} + +static void ggml_compute_forward_dup( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_dup_f16(params, src0, dst); + } break; + case 
GGML_TYPE_F32: + { + ggml_compute_forward_dup_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add + +static void ggml_compute_forward_add_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = src1->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + const int j0 = (n/nth)*ith; + const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); + + for (int j = j0; j < j1; j++) { + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + j*nb1), + (float *) ((char *) src0->data + j*nb01), + (float *) ((char *) src1->data + j*nb11)); + } + } else { + // src1 is not contiguous + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); + + dst_ptr[i] = src0_ptr[i] + *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_add( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sub + +static void ggml_compute_forward_sub_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sub_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_sub( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sub_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mul + +static void ggml_compute_forward_mul_f32( + const struct 
ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_mul_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_mul( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_div + +static void ggml_compute_forward_div_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_div_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_div( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_div_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sqr + +static void ggml_compute_forward_sqr_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqr( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqr_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + 
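+ // all non-F32 types fall through to assert(false) - sqr, like the other
+ // unary ops in this file, currently has a forward kernel only for F32.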
} +} + +// ggml_compute_forward_sqrt + +static void ggml_compute_forward_sqrt_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqrt( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqrt_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sum + +static void ggml_compute_forward_sum_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) (dst->data), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + } + } + } +} + +static void ggml_compute_forward_sum( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mean + +static void ggml_compute_forward_mean_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) ((char *) 
dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void ggml_compute_forward_mean( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mean_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_repeat + +static void ggml_compute_forward_repeat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: implement support for rank > 2 tensors + assert(src0->ne[2] == 1); + assert(src0->ne[3] == 1); + assert( dst->ne[2] == 1); + assert( dst->ne[3] == 1); + + const int nc = dst->ne[0]; + const int nr = dst->ne[1]; + const int nc0 = src0->ne[0]; + const int nr0 = src0->ne[1]; + const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat + const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat + + // TODO: support for transposed / permuted tensors + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i = 0; i < nrr; i++) { + for (int j = 0; j < ncr; j++) { + for (int k = 0; k < nr0; k++) { + ggml_vec_cpy_f32(nc0, + (float *) ((char *) dst->data + (i*nr0 + k)*( dst->nb[1]) + j*nc0*( dst->nb[0])), + (float *) ((char *) src0->data + ( k)*(src0->nb[1]))); + } + } + } +} + +static void ggml_compute_forward_repeat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_abs + +static void ggml_compute_forward_abs_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_abs( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_abs_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_sgn + +static void ggml_compute_forward_sgn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * 
src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sgn( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sgn_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_neg + +static void ggml_compute_forward_neg_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_neg( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_neg_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_step + +static void ggml_compute_forward_step_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_step( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_step_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_relu + +static void ggml_compute_forward_relu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); 
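+ // unit stride in dim 0 means each row is a dense block of nc floats,
+ // which is what ggml_vec_relu_f32 expects; same check for src0 below.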
+ assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_relu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_relu_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_gelu + +static void ggml_compute_forward_gelu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_norm + +static void ggml_compute_forward_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const ggml_float eps = 1e-5f; // TODO: make this a parameter + + // TODO: optimize + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float mean = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + mean += x[i00]; + } + + mean /= ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_float sum2 = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + ggml_float v = x[i00] - mean; + y[i00] = v; + sum2 += v*v; + } + + const float scale = 1.0/sqrt(sum2/ne00 + 
eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_norm_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_mul_mat + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +static bool ggml_compute_forward_mul_mat_use_blas( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + UNUSED(src0); + + const int ne10 = src1->ne[0]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ( + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32) + )) { + //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); + return true; + } + + return false; +} +#endif + +static void ggml_compute_forward_mul_mat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + assert(ne02 == ne12); + assert(ne03 == ne13); + assert(ne2 == ne12); + assert(ne3 == ne13); + + // TODO: we don't support permuted src0 + assert(nb00 == sizeof(float) || nb01 == sizeof(float)); + + // dst cannot be transposed or permuted + assert(nb0 == sizeof(float)); + assert(nb0 <= nb1); + assert(nb1 <= nb2); + assert(nb2 <= nb3); + + assert(ne0 == ne01); + assert(ne1 == ne11); + assert(ne2 == ne02); + assert(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + const float * x = (float *) (src0->data); + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + { + cblas_sgemm(CblasRowMajor, CblasNoTrans, 
CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + } + + //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + float * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); + + for (int k = 1; k < nth; k++) { + ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); + } + + return; + } + + if (nb01 >= nb00) { + // TODO: do not support transposed src1 + assert(nb10 == sizeof(float)); + + // parallelize by src0 rows using ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + ggml_vec_dot_f32(ne00, + (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), + (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f32 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + float * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_vec_mad_f32(ne01, + (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), + (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), + *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); + } + } + } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // 
printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + for (int i00 = 0; i00 < ne00; ++i00) { + wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + } + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } + //} + + { +#if 1 + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne00, + x, ne00, + 0.0f, d, ne01); +#else + // zT = (xT * y)T + 
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, + ne01, ne11, ne10, + 1.0f, x, ne00, + y, ne00, + 0.0f, d, ne01); +#endif + } + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (nb01 >= nb00) { + ggml_fp16_t * const wdata = params->wdata; + + int id = 0; + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + } + } + } + } + + GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); + + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + // TODO: fix this memset (wsize is overestimated) + //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); + + ggml_fp16_t * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); + } + + for (int k = 1; k < nth; k++) { + for (int i = ic0; i < ic1; ++i) { + ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); + } + } + + return; + } + + if (nb01 >= nb00) { + // fp16 -> half the size, so divide by 2 + // TODO: do not support transposed src1 + assert(nb10/2 == sizeof(ggml_fp16_t)); + + // parallelize by src0 rows using ggml_vec_dot_f16 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + ggml_fp16_t * wdata = params->wdata; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); + } + } + } else { + // parallelize by src1 columns using ggml_vec_mad_f16 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + ggml_fp16_t * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + + for (int ic 
= ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + + ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); + } + } + } + } + } + + //int64_t t1 = ggml_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_mul_mat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_mat_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_scale + +static void ggml_compute_forward_scale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scale factor + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), v); + } +} + +static void ggml_compute_forward_scale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_scale_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_cpy + +static void ggml_compute_forward_cpy( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + +// ggml_compute_forward_reshape + +static void ggml_compute_forward_reshape( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(src0); + UNUSED(dst); +} + +// ggml_compute_forward_view + 
+static void ggml_compute_forward_view( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_permute + +static void ggml_compute_forward_permute( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_transpose + +static void ggml_compute_forward_transpose( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_get_rows + +static void ggml_compute_forward_get_rows_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); + } + } +} + +static void ggml_compute_forward_get_rows_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i*dst->nb[1]), + (float *) ((char *) src0->data + r*src0->nb[1])); + } +} + +static void ggml_compute_forward_get_rows( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_diag_mask_inf + +static void ggml_compute_forward_diag_mask_inf_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 1); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + + // TODO: handle transposed/permuted matrices + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = 0; j < nr; j++) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + 
k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = -INFINITY; + } + } + } + } +} + +static void ggml_compute_forward_diag_mask_inf( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_soft_max + +static void ggml_compute_forward_soft_max_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *p = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(p[i])); + } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, p); + + ggml_float sum = 0.0; + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (p[i] == -INFINITY) { + p[i] = 0.0f; + } else { + //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += val; + p[i] = val; + } + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(nc, p, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(p[i])); + assert(!isinf(p[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_f32(params, src0, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_rope + +static void ggml_compute_forward_rope_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 3); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + //const int ne0 = src0->ne[0]; + const int ne1 = src0->ne[1]; + const int ne2 = src0->ne[2]; + const int ne3 = src0->ne[3]; + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(float)); + + // TODO: optimize 
+ for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = (mode == 0 ? n_past + i2 : i2); + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < n_dims; i0 += 2) { + const double theta = pow(10000.0, ((double)-i0)/n_dims); + + const double cos_theta = cos(p*theta); + const double sin_theta = sin(p*theta); + + const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + double x0 = src[0]; + double x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } + } + } +} + +static void ggml_compute_forward_rope( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_conv_1d_1s + +static void ggml_compute_forward_conv_1d_1s_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int i10 
= 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_1s_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + 
for (int i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_1s( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_1d_2s + +static void ggml_compute_forward_conv_1d_2s_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this 
thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_2s_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + //const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + //const int ne0 = dst->ne[0]; + //const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + //const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + //const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + //const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + 
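+// note: the _1s/_2s suffix is the convolution stride. conv_1d_2s computes the
+// same kernel/source dot products as conv_1d_1s above, but only at every
+// second input position (i0 += 2), so dst holds ne10/2 output samples -
+// e.g. an input of ne10 = 8 produces dst_data[0..3], evaluated at i0 = 0, 2, 4, 6
+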
+static void ggml_compute_forward_conv_1d_2s( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_attn + +static void ggml_compute_forward_flash_attn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + //const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices 
+ const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + float sum = 0.0f; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f32(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S); + } + } +} + +static void ggml_compute_forward_flash_attn_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + //const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + 
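+    // (dst stays F32 even in the fp16 variant - only q/k/v are stored as ggml_fp16_t)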
GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + float sum = 0.0f; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } else { + for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + 
const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16_unroll(nek1, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_flash_ff + +static void ggml_compute_forward_flash_ff_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, // F16 + const struct ggml_tensor * b0, // F16 fc_w + const struct ggml_tensor * b1, // F32 fc_b + const struct ggml_tensor * c0, // F16 proj_w + const struct ggml_tensor * c1, // F32 proj_b + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int nea0 = a->ne[0]; + const int nea1 = a->ne[1]; + const int nea2 = a->ne[2]; + const int nea3 = a->ne[3]; + + const int neb00 = b0->ne[0]; + const int neb01 = b0->ne[1]; + //const int neb02 = b0->ne[2]; + //const int neb03 = b0->ne[3]; + + const int neb10 = b1->ne[0]; + const int neb11 = b1->ne[1]; + //const int neb12 = b1->ne[2]; + //const int neb13 = b1->ne[3]; + + const int nec00 = c0->ne[0]; + const int nec01 = c0->ne[1]; + //const int nec02 = c0->ne[2]; + //const int nec03 = c0->ne[3]; + + const int nec10 = c1->ne[0]; + const int nec11 = c1->ne[1]; + //const int nec12 = c1->ne[2]; + //const int nec13 = c1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nba0 = a->nb[0]; + const int nba1 = a->nb[1]; + const int nba2 = a->nb[2]; + const int nba3 = a->nb[3]; + + const int nbb00 = b0->nb[0]; + const int nbb01 = b0->nb[1]; + const int nbb02 = b0->nb[2]; + const int nbb03 = b0->nb[3]; + + const int nbb10 = b1->nb[0]; + //const int nbb11 = b1->nb[1]; + //const int nbb12 = b1->nb[2]; + //const int nbb13 = b1->nb[3]; + + const int nbc00 = c0->nb[0]; + const int nbc01 = c0->nb[1]; + const int nbc02 = c0->nb[2]; + const int nbc03 = c0->nb[3]; + + const int nbc10 = c1->nb[0]; + //const int nbc11 = c1->nb[1]; + //const int nbc12 = c1->nb[2]; + //const int nbc13 = c1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = nea0; + //const int N = nea1; + const int M = neb01; + + GGML_ASSERT(ne0 == nea0); + GGML_ASSERT(ne1 == nea1); + GGML_ASSERT(ne2 == nea2); + + GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb10 == sizeof(float)); + GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbc10 == sizeof(float)); + + GGML_ASSERT(neb00 == D); + GGML_ASSERT(neb01 == M); + GGML_ASSERT(neb10 == M); + GGML_ASSERT(neb11 == 1); + + GGML_ASSERT(nec00 == M); + GGML_ASSERT(nec01 == D); + GGML_ASSERT(nec10 == D); + GGML_ASSERT(nec11 == 1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + 
GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by a rows using ggml_vec_dot_f32 + + // total rows in a + const int nr = nea1*nea2*nea3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // a indices + const int ia3 = ir/(nea2*nea1); + const int ia2 = (ir - ia3*nea2*nea1)/nea1; + const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < neb01; ++ic) { + // b0 indices + const int ib03 = ia3; + const int ib02 = ia2; + const int ib01 = ic; + + // S indices + const int i1 = ib01; + + ggml_vec_dot_f16(nea0, + S + i1, + (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), + (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); + } + + ggml_vec_add_f32(neb01, S, S, (float *) b1->data); + //ggml_vec_gelu_f32(neb01, S, S); + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + ggml_vec_gelu_f16(neb01, S16, S16); + + { + // dst indices + const int i1 = ia1; + const int i2 = ia2; + const int i3 = ia3; + + for (int ic = 0; ic < nec01; ++ic) { + + ggml_vec_dot_f16(neb01, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), + S16); + } + + ggml_vec_add_f32(nec01, + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) c1->data); + } + } +} + +static void ggml_compute_forward_flash_ff( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b0, + const struct ggml_tensor * b1, + const struct ggml_tensor * c0, + const struct ggml_tensor * c1, + struct ggml_tensor * dst) { + switch (b0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); // TODO + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +///////////////////////////////// + +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + assert(params); + + switch (tensor->op) { + case GGML_OP_DUP: + { + ggml_compute_forward_dup(params, tensor->src0, tensor); + } break; + case GGML_OP_ADD: + { + ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SUB: + { + ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_MUL: + { + ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_DIV: + { + ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SQR: + { + ggml_compute_forward_sqr(params, tensor->src0, tensor); + } break; + case GGML_OP_SQRT: + { + ggml_compute_forward_sqrt(params, tensor->src0, tensor); + } break; + case GGML_OP_SUM: + { + ggml_compute_forward_sum(params, tensor->src0, tensor); + } break; + case GGML_OP_MEAN: + { + ggml_compute_forward_mean(params, 
tensor->src0, tensor); + } break; + case GGML_OP_REPEAT: + { + ggml_compute_forward_repeat(params, tensor->src0, tensor); + } break; + case GGML_OP_ABS: + { + ggml_compute_forward_abs(params, tensor->src0, tensor); + } break; + case GGML_OP_SGN: + { + ggml_compute_forward_sgn(params, tensor->src0, tensor); + } break; + case GGML_OP_NEG: + { + ggml_compute_forward_neg(params, tensor->src0, tensor); + } break; + case GGML_OP_STEP: + { + ggml_compute_forward_step(params, tensor->src0, tensor); + } break; + case GGML_OP_RELU: + { + ggml_compute_forward_relu(params, tensor->src0, tensor); + } break; + case GGML_OP_GELU: + { + ggml_compute_forward_gelu(params, tensor->src0, tensor); + } break; + case GGML_OP_NORM: + { + ggml_compute_forward_norm(params, tensor->src0, tensor); + } break; + case GGML_OP_MUL_MAT: + { + ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SCALE: + { + ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CPY: + { + ggml_compute_forward_cpy(params, tensor->src0, tensor); + } break; + case GGML_OP_RESHAPE: + { + ggml_compute_forward_reshape(params, tensor->src0, tensor); + } break; + case GGML_OP_VIEW: + { + ggml_compute_forward_view(params, tensor->src0); + } break; + case GGML_OP_PERMUTE: + { + ggml_compute_forward_permute(params, tensor->src0); + } break; + case GGML_OP_TRANSPOSE: + { + ggml_compute_forward_transpose(params, tensor->src0); + } break; + case GGML_OP_GET_ROWS: + { + ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_DIAG_MASK_INF: + { + ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_SOFT_MAX: + { + ggml_compute_forward_soft_max(params, tensor->src0, tensor); + } break; + case GGML_OP_ROPE: + { + ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CONV_1D_1S: + { + ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CONV_1D_2S: + { + ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_FLASH_ATTN: + { + int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); + } break; + case GGML_OP_FLASH_FF: + { + ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) { + struct ggml_tensor * src0 = tensor->src0; + struct ggml_tensor * src1 = tensor->src1; + + switch (tensor->op) { + case GGML_OP_DUP: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_ADD: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_SUB: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_sub_impl(ctx, 
src1->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_MUL:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx, src1, tensor->grad),
+                                inplace);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                                src1->grad,
+                                ggml_mul(ctx, src0, tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_DIV:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_div(ctx, tensor->grad, src1),
+                                inplace);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_sub_impl(ctx,
+                                src1->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_div(ctx, tensor, src1)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SQR:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    ggml_mul(ctx, src0, tensor->grad),
+                                    ggml_repeat(ctx, ggml_new_f32(ctx, 2.0f), src0)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SQRT:
+            {
+                if (src0->grad) {
+                    // d/dx sqrt(x) = 0.5/sqrt(x) = 0.5/tensor, times the incoming gradient
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_div(ctx,
+                                        ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
+                                        tensor)),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SUM:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_repeat(ctx, tensor->grad, src0->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_MEAN:
+            {
+                assert(false); // TODO: implement
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_sum(ctx, tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_ABS:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    ggml_sgn(ctx, src0),
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SGN:
+            {
+                if (src0->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_NEG:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+                }
+            } break;
+        case GGML_OP_STEP:
+            {
+                if (src0->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_RELU:
+            {
+                if (src0->grad) {
+                    // relu'(x) = step(x), so the masked gradient accumulates into src0->grad
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_mul(ctx,
+                                ggml_step(ctx, src0),
+                                tensor->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_GELU:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_NORM:
+            {
+                assert(false); // TODO: not implemented
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                if (src0->grad) {
+                    // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
+                    assert(false);
+                }
+                if (src1->grad) {
+                    src1->grad =
+                        ggml_add_impl(ctx,
+                                src1->grad,
+                                // TODO: fix transpose, the node will break the graph connections
+                                ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_SCALE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CPY:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_RESHAPE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_VIEW:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_ROPE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CONV_1D_1S:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            }
break; + case GGML_OP_CONV_1D_2S: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_FLASH_ATTN: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_FLASH_FF: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { + if (node->grad == NULL) { + // this usually happens when we generate intermediate nodes from constants in the backward pass + // it can also happen during forward pass, if the user performs computations with constants + if (node->op != GGML_OP_NONE) { + //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); + } + } + + // check if already visited + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return; + } + } + + for (int i = 0; i < cgraph->n_leafs; i++) { + if (cgraph->leafs[i] == node) { + return; + } + } + + if (node->src0) { + ggml_visit_parents(cgraph, node->src0); + } + + if (node->src1) { + ggml_visit_parents(cgraph, node->src1); + } + + for (int i = 0; i < GGML_MAX_OPT; ++i) { + if (node->opt[i]) { + ggml_visit_parents(cgraph, node->opt[i]); + } + } + + if (node->op == GGML_OP_NONE && node->grad == NULL) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + assert(cgraph->n_leafs < GGML_MAX_NODES); + + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + assert(cgraph->n_nodes < GGML_MAX_NODES); + + cgraph->nodes[cgraph->n_nodes] = node; + cgraph->grads[cgraph->n_nodes] = node->grad; + cgraph->n_nodes++; + } +} + +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { + if (!expand) { + cgraph->n_nodes = 0; + cgraph->n_leafs = 0; + } + + const int n0 = cgraph->n_nodes; + UNUSED(n0); + + ggml_visit_parents(cgraph, tensor); + + const int n_new = cgraph->n_nodes - n0; + GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); + + if (n_new > 0) { + // the last added node should always be starting point + assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + } +} + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + ggml_build_forward_impl(cgraph, tensor, true); +} + +struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { + struct ggml_cgraph result = { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.n_threads =*/ 0, + /*.work_size =*/ 0, + /*.work =*/ NULL, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + ggml_build_forward_impl(&result, tensor, false); + + return result; +} + +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { + struct ggml_cgraph result = *gf; + + assert(gf->n_nodes > 0); + + // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph + if (keep) { + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + node->grad = ggml_dup_tensor(ctx, node); + gf->grads[i] = node->grad; + } + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + // because we detached the grad nodes from the original graph, we can afford inplace operations + if (node->grad) { + ggml_compute_backward(ctx, 
node, keep); + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->is_param) { + GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + ggml_build_forward_impl(&result, node->grad, true); + } + } + + return result; +} + +// +// thread data +// +// synchronization is done via busy loops +// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops +// + +#ifdef __APPLE__ + +//#include +// +//typedef os_unfair_lock ggml_lock_t; +// +//#define ggml_lock_init(x) UNUSED(x) +//#define ggml_lock_destroy(x) UNUSED(x) +//#define ggml_lock_lock os_unfair_lock_lock +//#define ggml_lock_unlock os_unfair_lock_unlock +// +//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#else + +//typedef pthread_spinlock_t ggml_lock_t; + +//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) +//#define ggml_lock_destroy pthread_spin_destroy +//#define ggml_lock_lock pthread_spin_lock +//#define ggml_lock_unlock pthread_spin_unlock + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#endif + +struct ggml_compute_state_shared { + ggml_lock_t spin; + + int n_threads; + + // synchronization primitives + atomic_int n_ready; + atomic_bool has_work; + atomic_bool stop; // stop all threads +}; + +struct ggml_compute_state { + ggml_thread_t thrd; + + struct ggml_compute_params params; + struct ggml_tensor * node; + + struct ggml_compute_state_shared * shared; +}; + +static thread_ret_t ggml_graph_compute_thread(void * data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + + const int n_threads = state->shared->n_threads; + + while (true) { + if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) { + atomic_store(&state->shared->has_work, false); + } else { + while (atomic_load(&state->shared->has_work)) { + if (atomic_load(&state->shared->stop)) { + return 0; + } + ggml_lock_lock (&state->shared->spin); + ggml_lock_unlock(&state->shared->spin); + } + } + + atomic_fetch_sub(&state->shared->n_ready, 1); + + // wait for work + while (!atomic_load(&state->shared->has_work)) { + if (atomic_load(&state->shared->stop)) { + return 0; + } + ggml_lock_lock (&state->shared->spin); + ggml_lock_unlock(&state->shared->spin); + } + + // check if we should stop + if (atomic_load(&state->shared->stop)) { + break; + } + + if (state->node) { + if (state->params.ith < state->params.nth) { + ggml_compute_forward(&state->params, state->node); + } + + state->node = NULL; + } else { + break; + } + } + + return 0; +} + +void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { + if (cgraph->n_threads <= 0) { + cgraph->n_threads = 8; + } + + const int n_threads = cgraph->n_threads; + + struct ggml_compute_state_shared state_shared = { + /*.spin =*/ GGML_LOCK_INITIALIZER, + /*.n_threads =*/ n_threads, + /*.n_ready =*/ 
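+        // n_ready counts the workers that have arrived at the current sync point;
+        // together with has_work it forms the two-phase busy-wait barrier used by
+        // ggml_graph_compute_thread() above: the main thread publishes a phase by
+        // setting has_work, and the last thread to arrive clears it again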
0, + /*.has_work =*/ false, + /*.stop =*/ false, + }; + struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; + + // create thread pool + if (n_threads > 1) { + ggml_lock_init(&state_shared.spin); + + atomic_store(&state_shared.has_work, true); + + for (int j = 0; j < n_threads - 1; j++) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .params = { + .type = GGML_TASK_COMPUTE, + .ith = j + 1, + .nth = n_threads, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? cgraph->work->data : NULL, + }, + .node = NULL, + .shared = &state_shared, + }; + + int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + assert(rc == 0); + UNUSED(rc); + } + } + + // initialize tasks + work buffer + { + size_t work_size = 0; + + // thread scheduling for the different operations + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case GGML_OP_DUP: + { + node->n_tasks = 1; + } break; + case GGML_OP_ADD: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SUM: + case GGML_OP_MEAN: + case GGML_OP_REPEAT: + case GGML_OP_ABS: + case GGML_OP_SGN: + case GGML_OP_NEG: + case GGML_OP_STEP: + case GGML_OP_RELU: + { + node->n_tasks = 1; + } break; + case GGML_OP_GELU: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_NORM: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_MUL_MAT: + { + node->n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + //const int nr0 = ggml_nrows(node->src0); + //const int nr1 = ggml_nrows(node->src1); + + //node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); + + size_t cur = 0; + + // TODO: better way to determine if the matrix is transposed + if (node->src0->nb[1] < node->src0->nb[0]) { + cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) + } else { + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); + //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); + //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); + //printf("cur = %zu\n", cur); + } else { + cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); + } +#else + cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); +#endif + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = 0; + } else { + GGML_ASSERT(false); + } + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_SCALE: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_CPY: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_DIAG_MASK_INF: + { + node->n_tasks = 1; + } break; + case GGML_OP_SOFT_MAX: + { + node->n_tasks = n_threads; + } break; + case GGML_OP_ROPE: + { + node->n_tasks = 
1; + } break; + case GGML_OP_CONV_1D_1S: + case GGML_OP_CONV_1D_2S: + { + node->n_tasks = n_threads; + + GGML_ASSERT(node->src0->ne[3] == 1); + GGML_ASSERT(node->src1->ne[2] == 1); + GGML_ASSERT(node->src1->ne[3] == 1); + + size_t cur = 0; + const int nk = node->src0->ne[0]; + + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else { + GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_FF: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_NONE: + { + node->n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + assert(false); + } break; + } + } + + if (cgraph->work != NULL && work_size > cgraph->work_size) { + assert(false); // TODO: better handling + } + + if (work_size > 0 && cgraph->work == NULL) { + cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1); + + GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size); + cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size); + } + } + + const int64_t perf_start_cycles = ggml_perf_cycles(); + const int64_t perf_start_time_us = ggml_perf_time_us(); + + for (int i = 0; i < cgraph->n_nodes; i++) { + GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes); + + struct ggml_tensor * node = cgraph->nodes[i]; + + // TODO: this could be used to avoid unnecessary computations, but it needs to be improved + //if (node->grad == NULL && node->perf_runs > 0) { + // continue; + //} + + const int64_t perf_node_start_cycles = ggml_perf_cycles(); + const int64_t perf_node_start_time_us = ggml_perf_time_us(); + + // INIT + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_INIT, + /*.ith =*/ 0, + /*.nth =*/ node->n_tasks, + /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, + /*.wdata =*/ cgraph->work ? 
cgraph->work->data : NULL, + }; + + ggml_compute_forward(¶ms, node); + + // COMPUTE + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + // launch thread pool + for (int j = 0; j < n_threads - 1; j++) { + workers[j].params = (struct ggml_compute_params) { + .type = GGML_TASK_COMPUTE, + .ith = j + 1, + .nth = node->n_tasks, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? cgraph->work->data : NULL, + }; + workers[j].node = node; + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) > 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_store(&state_shared.has_work, true); + } + + params.type = GGML_TASK_COMPUTE; + ggml_compute_forward(¶ms, node); + + // wait for thread pool + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) != 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + } + + // FINALIZE + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + // launch thread pool + for (int j = 0; j < n_threads - 1; j++) { + workers[j].params = (struct ggml_compute_params) { + .type = GGML_TASK_FINALIZE, + .ith = j + 1, + .nth = node->n_tasks, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? 
cgraph->work->data : NULL, + }; + workers[j].node = node; + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) > 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_store(&state_shared.has_work, true); + } + + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + + // wait for thread pool + if (node->n_tasks > 1) { + if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) { + atomic_store(&state_shared.has_work, false); + } + + while (atomic_load(&state_shared.has_work)) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + + atomic_fetch_sub(&state_shared.n_ready, 1); + + while (atomic_load(&state_shared.n_ready) != 0) { + ggml_lock_lock (&state_shared.spin); + ggml_lock_unlock(&state_shared.spin); + } + } + + // performance stats (node) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_node_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += perf_cycles_cur; + node->perf_time_us += perf_time_us_cur; + } + } + + // join thread pool + if (n_threads > 1) { + atomic_store(&state_shared.stop, true); + atomic_store(&state_shared.has_work, true); + + for (int j = 0; j < n_threads - 1; j++) { + int rc = ggml_thread_join(workers[j].thrd, NULL); + assert(rc == 0); + UNUSED(rc); + } + + ggml_lock_destroy(&state_shared.spin); + } + + // performance stats (graph) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; + + cgraph->perf_runs++; + cgraph->perf_cycles += perf_cycles_cur; + cgraph->perf_time_us += perf_time_us_cur; + + GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", + __func__, cgraph->perf_runs, + (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), + (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs, + (double) perf_time_us_cur / 1000.0, + (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); + } +} + +void ggml_graph_reset(struct ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + ggml_set_zero(grad); + } + } +} + +void ggml_graph_print(const struct ggml_cgraph * cgraph) { + int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; + + GGML_PRINT("=== GRAPH ===\n"); + + GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads); + GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size); + + GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + perf_total_per_op_us[node->op] += node->perf_time_us; + + GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + i, + node->ne[0], node->ne[1], node->ne[2], + GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? 
"g" : " ", node->perf_runs, + (double) node->perf_cycles / (double) ggml_cycles_per_ms(), + (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, + (double) node->perf_time_us / 1000.0, + (double) node->perf_time_us / 1000.0 / node->perf_runs); + } + + GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * node = cgraph->leafs[i]; + + GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n", + i, + node->ne[0], node->ne[1], + GGML_OP_LABEL[node->op]); + } + + for (int i = 0; i < GGML_OP_COUNT; i++) { + GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0); + } + + GGML_PRINT("========================================\n"); +} + +// check if node is part of the graph +static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + if (cgraph == NULL) { + return true; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; + } + } + + return false; +} + +static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * parent = cgraph->nodes[i]; + + if (parent->grad == node) { + return parent; + } + } + + return NULL; +} + +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { + char color[16]; + + FILE * fp = fopen(filename, "w"); + assert(fp); + + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = LR;\n"); + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + if (ggml_graph_get_parent(gb, node) != NULL) { + continue; + } + + if (node->is_param) { + snprintf(color, sizeof(color), "yellow"); + } else if (node->grad) { + if (ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); + } else { + snprintf(color, sizeof(color), "lightblue"); + } + } else { + snprintf(color, sizeof(color), "white"); + } + + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"%d [%d, %d] | %s", + (void *) node, color, + i, node->ne[0], node->ne[1], + GGML_OP_SYMBOL[node->op]); + + if (node->grad) { + fprintf(fp, " | %s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]); + } else { + fprintf(fp, "\"; ]\n"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + snprintf(color, sizeof(color), "pink"); + + if (ggml_nelements(node) == 1) { + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"%.1e\"; ]\n", + (void *) node, color, ggml_get_f32_1d(node, 0)); + } else { + fprintf(fp, " \"%p\" [ \ +style = filled; fillcolor = %s; shape = record; \ +label=\"CONST %d [%d, %d]\"; ]\n", + (void *) node, color, + i, node->ne[0], node->ne[1]); + } + } + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + struct ggml_tensor * parent = ggml_graph_get_parent(gb, node); + + if (node->src0) { + struct ggml_tensor * parent0 = ggml_graph_get_parent(gb, node->src0); + + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"x\"; ]\n", + parent0 ? (void *) parent0 : (void *) node->src0, + parent0 ? "g" : "x", + parent ? (void *) parent : (void *) node, + parent ? "g" : "x", + parent ? "empty" : "vee", + parent ? 
"dashed" : "solid"); + } + + if (node->src1) { + struct ggml_tensor * parent1 = ggml_graph_get_parent(gb, node->src1); + + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"y\"; ]\n", + parent1 ? (void *) parent1 : (void *) node->src1, + parent1 ? "g" : "x", + parent ? (void *) parent : (void *) node, + parent ? "g" : "x", + parent ? "empty" : "vee", + parent ? "dashed" : "solid"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + if (node->src0) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"x\"; ]\n", + (void *) node->src0, "x", + (void *) node, "x"); + } + + if (node->src1) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"y\"; ]\n", + (void *) node->src1, "x", + (void *) node, "x"); + } + } + + fprintf(fp, "}\n"); + + fclose(fp); + + GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to set tensor from array + for (int j = 0; j < ne; ++j) { + ggml_set_f32_1d(ps[p], j, x[i++]); + } + } +} + +static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int j = 0; j < ne; ++j) { + x[i++] = ggml_get_f32_1d(ps[p], j); + } + } +} + +static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int j = 0; j < ne; ++j) { + g[i++] = ggml_get_f32_1d(ps[p]->grad, j); + } + } +} + +// +// ADAM +// +// ref: https://arxiv.org/pdf/1412.6980.pdf +// + +static enum ggml_opt_result ggml_opt_adam( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb) { + assert(ggml_is_scalar(f)); + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + assert(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + // constants + const float alpha = params.adam.alpha; + const float beta1 = params.adam.beta1; + const float beta2 = params.adam.beta2; + const float eps = params.adam.eps; + + float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters + float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient + float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared + float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment + float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment + float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat + float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat + + float * pf = params.past > 0 ? 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values + + // initialize + ggml_vec_set_f32(nx, m, 0.0f); + ggml_vec_set_f32(nx, v, 0.0f); + + // update view + ggml_opt_get_params(np, ps, x); + + // compute the function value + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + float fx_prev = ggml_get_f32_1d(f, 0); + if (pf) { + pf[0] = fx_prev; + } + + int n_no_improvement = 0; + float fx_best = fx_prev; + + // run the optimizer + for (int t = 0; t < params.adam.n_iter; ++t) { + GGML_PRINT_DEBUG ("=== iter %d ===\n", t); + + GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); + GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); + + for (int i = 0; i < np; ++i) { + GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, + ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); + } + + const int64_t t_start_wall = ggml_time_us(); + const int64_t t_start_cpu = ggml_cycles(); + UNUSED(t_start_wall); + UNUSED(t_start_cpu); + + { + // update the gradient + ggml_opt_get_grad(np, ps, g1); + + // m_t = beta1*m_t-1 + (1 - beta1)*g_t + ggml_vec_scale_f32(nx, m, beta1); + ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); + + // g2 = g1^2 + ggml_vec_sqr_f32 (nx, g2, g1); + + // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 + ggml_vec_scale_f32(nx, v, beta2); + ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); + + // m^hat = m_t / (1 - beta1^t) + // v^hat = v_t / (1 - beta2^t) + // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps) + ggml_vec_cpy_f32 (nx, mh, m); + ggml_vec_cpy_f32 (nx, vh, v); + + ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1))); + ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1))); + + ggml_vec_sqrt_f32 (nx, vh, vh); + ggml_vec_acc1_f32 (nx, vh, eps); + + ggml_vec_div_f32 (nx, mh, mh, vh); + ggml_vec_sub_f32 (nx, x, x, mh); + + // update the parameters + ggml_opt_set_params(np, ps, x); + } + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + const float fx = ggml_get_f32_1d(f, 0); + + // check convergence + if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) { + GGML_PRINT_DEBUG("converged\n"); + + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= t) { + const float rate = (pf[t%params.past] - fx)/fx; + + if (fabs(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[t%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx_best > fx) { + fx_best = fx; + n_no_improvement = 0; + } else { + ++n_no_improvement; + + if (n_no_improvement >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + fx_prev = fx; + + { + const int64_t t_end_cpu = ggml_cycles(); + GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); + UNUSED(t_end_cpu); + + const int64_t t_end_wall = ggml_time_us(); + GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); + UNUSED(t_end_wall); + } + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +// +// L-BFGS +// +// the L-BFGS implementation below is based on the following implementation: +// +// https://github.com/chokkan/liblbfgs +// + +struct ggml_lbfgs_iteration_data { + float alpha; + float ys; + float * s; + float * y; +}; + +static enum ggml_opt_result linesearch_backtracking( + struct ggml_context * ctx, + 
const struct ggml_opt_params * params, + int nx, + float * x, + float * fx, + float * g, + float * d, + float * step, + const float * xp, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + const int np, + struct ggml_tensor * ps[]) { + int count = 0; + + float width = 0.0f; + float dg = 0.0f; + float finit = 0.0f; + float dginit = 0.0f; + float dgtest = 0.0f; + + const float dec = 0.5f; + const float inc = 2.1f; + + if (*step <= 0.) { + return GGML_LINESEARCH_INVALID_PARAMETERS; + } + + // compute the initial gradient in the search direction + ggml_vec_dot_f32(nx, &dginit, g, d); + + // make sure that d points to a descent direction + if (0 < dginit) { + return GGML_LINESEARCH_FAIL; + } + + // initialize local variables + finit = *fx; + dgtest = params->lbfgs.ftol*dginit; + + while (true) { + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_mad_f32(nx, x, d, *step); + + // evaluate the function and gradient values + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + ggml_opt_get_grad(np, ps, g); + + *fx = ggml_get_f32_1d(f, 0); + } + + ++count; + + if (*fx > finit + (*step)*dgtest) { + width = dec; + } else { + // Armijo condition is satisfied + if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { + return count; + } + + ggml_vec_dot_f32(nx, &dg, g, d); + + // check the Wolfe condition + if (dg < params->lbfgs.wolfe * dginit) { + width = inc; + } else { + if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { + // regular Wolfe conditions + return count; + } + + if(dg > -params->lbfgs.wolfe*dginit) { + width = dec; + } else { + // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) + return count; + } + return count; + } + } + + if (*step < params->lbfgs.min_step) { + return GGML_LINESEARCH_MINIMUM_STEP; + } + if (*step > params->lbfgs.max_step) { + return GGML_LINESEARCH_MAXIMUM_STEP; + } + if (params->lbfgs.max_linesearch <= count) { + return GGML_LINESEARCH_MAXIMUM_ITERATIONS; + } + + (*step) *= width; + } + + return GGML_LINESEARCH_FAIL; +} + +static enum ggml_opt_result ggml_opt_lbfgs( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb) { + if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || + params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) { + return GGML_OPT_INVALID_WOLFE; + } + } + + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + const int m = params.lbfgs.m; + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + assert(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters + float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters + float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient + float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient + float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction + + float * pf = params.past > 0 ? 
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values + + float fx = 0.0f; // cost function value + float xnorm = 0.0f; // ||x|| + float gnorm = 0.0f; // ||g|| + float step = 0.0f; + + // initialize x from the graph nodes + ggml_opt_get_params(np, ps, x); + + // the L-BFGS memory + struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m); + + for (int i = 0; i < m; ++i) { + lm[i].alpha = 0.0f; + lm[i].ys = 0.0f; + lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; + lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; + } + + // evaluate the function value and its gradient + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(ctx, gb); + + ggml_opt_get_grad(np, ps, g); + + fx = ggml_get_f32_1d(f, 0); + } + + if (pf) { + pf[0] = fx; + } + + float fx_best = fx; + + // search direction = -gradient + ggml_vec_neg_f32(nx, d, g); + + // ||x||, ||g|| + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + + // already optimized + if (gnorm/xnorm <= params.lbfgs.eps) { + return GGML_OPT_OK; + } + + // initial step + ggml_vec_norm_inv_f32(nx, &step, d); + + int j = 0; + int k = 1; + int ls = 0; + int end = 0; + int bound = 0; + int n_no_improvement = 0; + + float ys = 0.0f; + float yy = 0.0f; + float beta = 0.0f; + + while (true) { + // store the current position and gradient vectors + ggml_vec_cpy_f32(nx, xp, x); + ggml_vec_cpy_f32(nx, gp, g); + + ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps); + + if (ls < 0) { + // linesearch failed - go back to the previous point and return + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_cpy_f32(nx, g, gp); + + return ls; + } + + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + + if (xnorm < 1.0) { + xnorm = 1.0; + } + if (gnorm/xnorm <= params.lbfgs.eps) { + // converged + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= k) { + const float rate = (pf[k%params.past] - fx)/fx; + + if (fabs(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[k%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx < fx_best) { + fx_best = fx; + n_no_improvement = 0; + } else { + n_no_improvement++; + + if (n_no_improvement >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) { + // reached the maximum number of iterations + return GGML_OPT_DID_NOT_CONVERGE; + } + + // update vectors s and y: + // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. + // y_{k+1} = g_{k+1} - g_{k}. + // + ggml_vec_sub_f32(nx, lm[end].s, x, xp); + ggml_vec_sub_f32(nx, lm[end].y, g, gp); + + // compute scalars ys and yy: + // ys = y^t \cdot s -> 1 / \rho. + // yy = y^t \cdot y. + // + ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s); + ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y); + + lm[end].ys = ys; + + // find new search direction + // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS + + bound = (m <= k) ? 
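+        // i.e. bound = MIN(m, k): during the first iterations (k < m) only k
+        // correction pairs (s, y) have been stored, so only those can be used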
m : k; + k++; + end = (end + 1)%m; + + // initialize search direction with -g + ggml_vec_neg_f32(nx, d, g); + + j = end; + for (int i = 0; i < bound; ++i) { + j = (j + m - 1) % m; + // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} + ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d); + lm[j].alpha /= lm[j].ys; + // q_{i} = q_{i+1} - \alpha_{i} y_{i} + ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha); + } + + ggml_vec_scale_f32(nx, d, ys/yy); + + for (int i = 0; i < bound; ++i) { + // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} + ggml_vec_dot_f32(nx, &beta, lm[j].y, d); + beta /= lm[j].ys; + // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} + ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta); + j = (j + 1)%m; + } + + step = 1.0; + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { + struct ggml_opt_params result; + + switch (type) { + case GGML_OPT_ADAM: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_ADAM, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 100, + + .print_forward_graph = true, + .print_backward_graph = true, + + .adam = { + .n_iter = 10000, + .alpha = 0.001f, + .beta1 = 0.9f, + .beta2 = 0.999f, + .eps = 1e-8f, + .eps_f = 1e-5f, + .eps_g = 1e-3f, + }, + }; + } break; + case GGML_OPT_LBFGS: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_LBFGS, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 0, + + .print_forward_graph = true, + .print_backward_graph = true, + + .lbfgs = { + .m = 6, + .n_iter = 100, + .max_linesearch = 20, + + .eps = 1e-5f, + .ftol = 1e-4f, + .wolfe = 0.9f, + .min_step = 1e-20f, + .max_step = 1e+20f, + + .linesearch = GGML_LINESEARCH_DEFAULT, + }, + }; + } break; + } + + return result; +} + +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f) { + bool free_ctx = false; + if (ctx == NULL) { + struct ggml_init_params params_ctx = { + .mem_size = 16*1024*1024, + .mem_buffer = NULL, + }; + + ctx = ggml_init(params_ctx); + if (ctx == NULL) { + return GGML_OPT_NO_CONTEXT; + } + + free_ctx = true; + } + + enum ggml_opt_result result = GGML_OPT_OK; + + // build forward + backward compute graphs + struct ggml_cgraph gf = ggml_build_forward (f); + struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false); + + switch (params.type) { + case GGML_OPT_ADAM: + { + result = ggml_opt_adam(ctx, params, f, &gf, &gb); + } break; + case GGML_OPT_LBFGS: + { + result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb); + } break; + } + + if (params.print_forward_graph) { + ggml_graph_print (&gf); + ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot"); + } + + if (params.print_backward_graph) { + ggml_graph_print (&gb); + ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot"); + } + + if (free_ctx) { + ggml_free(ctx); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +int ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_neon(void) { +#if defined(__ARM_NEON) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) 
+ return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_blas(void) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_sse3(void) { +#if defined(__SSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_vsx(void) { +#if defined(__POWER9_VECTOR__) + return 1; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/bindings/ruby/ext/ggml.h b/bindings/ruby/ext/ggml.h new file mode 100644 index 0000000..18f317b --- /dev/null +++ b/bindings/ruby/ext/ggml.h @@ -0,0 +1,748 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. +// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph gf = ggml_build_forward(f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute(ctx0, &gf); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. 
They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? +// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3); +// +// // a[1, 2] = 1.0f; +// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f; +// +// // a[2, 0] = 2.0f; +// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f; +// +// ... 
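+//
+//    // reading an element uses the same stride arithmetic, e.g. a[1, 2]:
+//    const float v = *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]);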
+// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. +// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include + +#define GGML_MAX_DIMS 4 +#define GGML_MAX_NODES 4096 +#define GGML_MAX_PARAMS 16 +#define GGML_MAX_CONTEXTS 64 +#define GGML_MAX_OPT 4 + +#ifdef __ARM_NEON +// we use the built-in 16-bit float type +typedef __fp16 ggml_fp16_t; +#else +typedef uint16_t ggml_fp16_t; +#endif + +// convert FP16 <-> FP32 +float ggml_fp16_to_fp32(ggml_fp16_t x); +ggml_fp16_t ggml_fp32_to_fp16(float x); + +struct ggml_object; +struct ggml_context; + +enum ggml_type { + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_F16, + GGML_TYPE_F32, + GGML_TYPE_COUNT, +}; + +// available tensor operations: +enum ggml_op { + GGML_OP_NONE = 0, + + GGML_OP_DUP, + GGML_OP_ADD, + GGML_OP_SUB, + GGML_OP_MUL, + GGML_OP_DIV, + GGML_OP_SQR, + GGML_OP_SQRT, + GGML_OP_SUM, + GGML_OP_MEAN, + GGML_OP_REPEAT, + GGML_OP_ABS, + GGML_OP_SGN, + GGML_OP_NEG, + GGML_OP_STEP, + GGML_OP_RELU, + GGML_OP_GELU, + GGML_OP_NORM, // normalize + + GGML_OP_MUL_MAT, + + GGML_OP_SCALE, + GGML_OP_CPY, + GGML_OP_RESHAPE, + GGML_OP_VIEW, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_DIAG_MASK_INF, + GGML_OP_SOFT_MAX, + GGML_OP_ROPE, + GGML_OP_CONV_1D_1S, + GGML_OP_CONV_1D_2S, + + GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_FF, + + GGML_OP_COUNT, +}; + +// n-dimensional tensor +struct ggml_tensor { + enum ggml_type type; + + int n_dims; + int ne[GGML_MAX_DIMS]; // number of elements + size_t nb[GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = sizeof(type) + // nb[1] = nb[0] * ne[0] + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum ggml_op op; + + bool is_param; + + struct ggml_tensor * grad; + struct ggml_tensor * src0; + struct ggml_tensor * src1; + struct ggml_tensor * opt[GGML_MAX_OPT]; + + // thread scheduling + int n_tasks; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + + void * data; + char padding[8]; +}; + +// computation graph +struct ggml_cgraph { + int n_nodes; + int n_leafs; + int n_threads; + + size_t work_size; + struct ggml_tensor * work; + + struct ggml_tensor * nodes[GGML_MAX_NODES]; + struct ggml_tensor * grads[GGML_MAX_NODES]; + struct ggml_tensor * leafs[GGML_MAX_NODES]; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; +}; + +// scratch buffer +struct ggml_scratch { + size_t offs; + size_t size; + void * data; +}; + +struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally +}; + +void ggml_time_init(void); // call this once at the beginning of the program +int64_t ggml_time_ms(void); +int64_t ggml_time_us(void); +int64_t ggml_cycles(void); +int64_t ggml_cycles_per_ms(void); + +void ggml_print_object (const struct ggml_object * obj); +void ggml_print_objects(const struct ggml_context * ctx); + +int ggml_nelements(const struct ggml_tensor * tensor); +size_t ggml_nbytes (const struct ggml_tensor * tensor); + +size_t ggml_type_size (enum ggml_type type); +size_t ggml_element_size(const struct ggml_tensor * tensor); + +struct ggml_context * ggml_init(struct ggml_init_params params); +void 
ggml_free(struct ggml_context * ctx); + +size_t ggml_used_mem(const struct ggml_context * ctx); + +size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch); + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int *ne); + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0); + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1); + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2); + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int ne0, + int ne1, + int ne2, + int ne3); + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + +struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); +struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); +struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + void * ggml_get_data (const struct ggml_tensor * tensor); +float * ggml_get_data_f32(const struct ggml_tensor * tensor); + +// +// operations on tensors with backpropagation +// + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// return scalar +// TODO: compute sum along rows +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// mean along rows +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// if a is the same shape as b, and a is not parameter, return a +// otherwise, return a new tensor: repeat(a) to fit in b +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// 
TODO: double-check this computation is correct +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// normalize along rows +// TODO: eps is hardcoded to 1e-5 for now +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// A: m rows, n columns +// B: p rows, n columns (i.e. we transpose it internally) +// result is m columns, p rows +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// +// operations on tensors without backpropagation +// + +// in-place, returns view(a) +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// a -> b, return view(b) +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// return view(a), b specifies the new shape +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// return view(a) +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1); + +// return view(a) +// TODO: when we start computing gradient, make a copy instead of view +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2); + +// offset in bytes +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + size_t offset); + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + size_t nb1, // row stride in bytes + size_t offset); + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + +// alias for ggml_permute(ctx, a, 1, 0, 2, 3) +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a); + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +// set elements above the diagonal to -INF +// in-place, returns view(a) +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + +// in-place, returns view(a) +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a); + +// rotary position embedding +// in-place, returns view(a) +// if mode == 1, skip n_past elements +// TODO: avoid creating a new tensor every time +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode); + +// padding = 1 +// TODO: we don't support extra parameters for now +// that's why we are hard-coding the stride, padding, and dilation +// not great .. 
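+// (the _1s/_2s suffixes denote the stride: ggml_conv_1d_1s computes a 1D
+// convolution with stride 1, ggml_conv_1d_2s the stride-2 variant)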
+struct ggml_tensor * ggml_conv_1d_1s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_conv_1d_2s( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked); + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1); + +// +// automatic differentiation +// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor); + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + +struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); + +void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph); +void ggml_graph_reset (struct ggml_cgraph * cgraph); + +// print info and performance information for the graph +void ggml_graph_print(const struct ggml_cgraph * cgraph); + +// dump the graph into a file using the dot format +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + +// +// optimization +// + +// optimization methods +enum ggml_opt_type { + GGML_OPT_ADAM, + GGML_OPT_LBFGS, +}; + +// linesearch methods +enum ggml_linesearch { + GGML_LINESEARCH_DEFAULT = 1, + + GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, + GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, + GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, +}; + +// optimization return values +enum ggml_opt_result { + GGML_OPT_OK = 0, + GGML_OPT_DID_NOT_CONVERGE, + GGML_OPT_NO_CONTEXT, + GGML_OPT_INVALID_WOLFE, + GGML_OPT_FAIL, + + GGML_LINESEARCH_FAIL = -128, + GGML_LINESEARCH_MINIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_ITERATIONS, + GGML_LINESEARCH_INVALID_PARAMETERS, +}; + +// optimization parameters +// +// see ggml.c (ggml_opt_default_params) for default values +// +struct ggml_opt_params { + enum ggml_opt_type type; + + int n_threads; + + // delta-based convergence test + // + // if past == 0 - disabled + // if past > 0: + // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) + // + int past; + float delta; + + // maximum number of iterations without improvement + // + // if 0 - disabled + // if > 0: + // assume convergence if no cost improvement in this number of iterations + // + int max_no_improvement; + + bool print_forward_graph; + bool print_backward_graph; + + // ADAM parameters + struct { + int n_iter; + + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float eps_f; // epsilon for convergence test + float eps_g; // epsilon for convergence test + } adam; + + // LBFGS parameters + struct { + int m; // number of corrections to approximate the inv. 
Hessian + int n_iter; + int max_linesearch; + + float eps; // convergence tolerance + float ftol; // line search tolerance + float wolfe; + float min_step; + float max_step; + + enum ggml_linesearch linesearch; + } lbfgs; +}; + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); + +// optimize the function defined by the tensor f +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f); + +// +// system info +// + +int ggml_cpu_has_avx(void); +int ggml_cpu_has_avx2(void); +int ggml_cpu_has_avx512(void); +int ggml_cpu_has_fma(void); +int ggml_cpu_has_neon(void); +int ggml_cpu_has_arm_fma(void); +int ggml_cpu_has_f16c(void); +int ggml_cpu_has_fp16_va(void); +int ggml_cpu_has_wasm_simd(void); +int ggml_cpu_has_blas(void); +int ggml_cpu_has_sse3(void); +int ggml_cpu_has_vsx(void); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c new file mode 100644 index 0000000..4abfab5 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper.c @@ -0,0 +1,92 @@ +#include +#include "ruby_whisper.h" + +VALUE mWhisper; +VALUE cContext; +VALUE cParams; + +static void ruby_whisper_free(ruby_whisper *rw) { + if (rw->context) { + whisper_free(rw->context); + rw->context = NULL; + } +} +static void ruby_whisper_params_free(ruby_whisper_params *rwp) { + if (rwp->params) { + free(rwp->params); + rwp->params = NULL; + } +} + +void rb_whisper_mark(ruby_whisper *rw) { + // call rb_gc_mark on any ruby references in rw +} + +void rb_whisper_free(ruby_whisper *rw) { + ruby_whisper_free(rw); + free(rw); +} + +void rb_whisper_params_mark(ruby_whisper_params *rwp) { +} + +void rb_whisper_params_free(ruby_whisper_params *rwp) { + ruby_whisper_params_free(rwp); + free(rwp); +} + +static VALUE ruby_whisper_allocate(VALUE klass) { + ruby_whisper *rw; + rw = ALLOC(ruby_whisper); + rw->context = NULL; + return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); +} + +static VALUE ruby_whisper_params_allocate(VALUE klass) { + ruby_whisper_params *rwp; + rwp = ALLOC(ruby_whisper_params); + rwp->params = ALLOC(struct whisper_full_params); + return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); +} + +static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { + ruby_whisper *rw; + VALUE whisper_model_file_path; + + // TODO: we can support init from buffer here too maybe another ruby object to expose + rb_scan_args(argc, argv, "01", &whisper_model_file_path); + Data_Get_Struct(self, ruby_whisper, rw); + + if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); + } + rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path)); + return self; +} + +/* + * params.auto_detection = true|false + */ +static VALUE ruby_whisper_params_set_auto_detection(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->params->language = "auto"; + } else { + rwp->params->language = NULL; + } + return value; +} + +void Init_whisper() { + mWhisper = rb_define_module("Whisper"); + cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); + cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); + + rb_define_alloc_func(cContext, ruby_whisper_allocate); + rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); + + 
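+    // illustrative sketch of the resulting Ruby API ('model.bin' is a placeholder):
+    //
+    //   require 'whisper'
+    //   ctx    = Whisper::Context.new('model.bin')
+    //   params = Whisper::Params.new
+    //   params.auto_detection = true
+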
+    rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
+
+    rb_define_method(cParams, "auto_detection=", ruby_whisper_params_set_auto_detection, 1);
+}
diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h
new file mode 100644
index 0000000..246d133
--- /dev/null
+++ b/bindings/ruby/ext/ruby_whisper.h
@@ -0,0 +1,14 @@
+#ifndef __RUBY_WHISPER_H
+#define __RUBY_WHISPER_H
+
+#include "whisper.h"
+
+typedef struct {
+    struct whisper_context *context;
+} ruby_whisper;
+
+typedef struct {
+    struct whisper_full_params *params;
+} ruby_whisper_params;
+
+#endif
diff --git a/bindings/ruby/ext/whisper.cpp b/bindings/ruby/ext/whisper.cpp
new file mode 100644
index 0000000..24e16bd
--- /dev/null
+++ b/bindings/ruby/ext/whisper.cpp
@@ -0,0 +1,4814 @@
+#define WHISPER_BUILD
+#include "whisper.h"
+
+#include "ggml.h"
+
+#include <algorithm>
+#include <cassert>
+#define _USE_MATH_DEFINES
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <thread>
+#include <vector>
+#include <regex>
+#include <random>
+
+#if defined(GGML_BIG_ENDIAN)
+#include <bit>
+
+template <typename T>
+static T byteswap(T value) {
+    return std::byteswap(value);
+}
+
+template<>
+float byteswap(float value) {
+    return std::bit_cast<float>(byteswap(std::bit_cast<uint32_t>(value)));
+}
+
+template <typename T>
+static void byteswap_tensor_data(ggml_tensor * tensor) {
+    T * datum = reinterpret_cast<T *>(tensor->data);
+    for (int i = 0; i < ggml_nelements(tensor); i++) {
+        datum[i] = byteswap(datum[i]);
+    }
+}
+
+static void byteswap_tensor(ggml_tensor * tensor) {
+    switch (tensor->type) {
+        case GGML_TYPE_I16: {
+            byteswap_tensor_data<int16_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_F16: {
+            byteswap_tensor_data<ggml_fp16_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_I32: {
+            byteswap_tensor_data<int32_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_F32: {
+            byteswap_tensor_data<float>(tensor);
+            break;
+        }
+        default: { // GGML_TYPE_I8
+            break;
+        }
+    }
+}
+
+#define BYTESWAP_VALUE(d) d = byteswap(d)
+#define BYTESWAP_FILTERS(f)           \
+    do {                              \
+        for (auto & datum : f.data) { \
+            datum = byteswap(datum);  \
+        }                             \
+    } while (0)
+#define BYTESWAP_TENSOR(t)  \
+    do {                    \
+        byteswap_tensor(t); \
+    } while (0)
+#else
+#define BYTESWAP_VALUE(d) do {} while (0)
+#define BYTESWAP_FILTERS(f) do {} while (0)
+#define BYTESWAP_TENSOR(t) do {} while (0)
+#endif
+
+#define WHISPER_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
+
+// define this to enable verbose trace logging - useful for debugging purposes
+//#define WHISPER_DEBUG
+
+#if defined(WHISPER_DEBUG)
+#define WHISPER_PRINT_DEBUG(...) \
+    do { \
+        fprintf(stderr, __VA_ARGS__); \
+    } while (0)
+#else
+#define WHISPER_PRINT_DEBUG(...)
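+// note: the release definition above is intentionally empty, so
+// WHISPER_PRINT_DEBUG(...) calls compile away to nothing unless
+// WHISPER_DEBUG is defined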
+#endif + +#define WHISPER_USE_FLASH_ATTN +//#define WHISPER_USE_FLASH_FF +#define WHISPER_MAX_DECODERS 16 + +#define WHISPER_USE_SCRATCH +#define WHISPER_MAX_SCRATCH_BUFFERS 16 + +// available whisper models +enum e_model { + MODEL_UNKNOWN, + MODEL_TINY, + MODEL_BASE, + MODEL_SMALL, + MODEL_MEDIUM, + MODEL_LARGE, +}; + +static const std::map> g_lang = { + { "en", { 0, "english", } }, + { "zh", { 1, "chinese", } }, + { "de", { 2, "german", } }, + { "es", { 3, "spanish", } }, + { "ru", { 4, "russian", } }, + { "ko", { 5, "korean", } }, + { "fr", { 6, "french", } }, + { "ja", { 7, "japanese", } }, + { "pt", { 8, "portuguese", } }, + { "tr", { 9, "turkish", } }, + { "pl", { 10, "polish", } }, + { "ca", { 11, "catalan", } }, + { "nl", { 12, "dutch", } }, + { "ar", { 13, "arabic", } }, + { "sv", { 14, "swedish", } }, + { "it", { 15, "italian", } }, + { "id", { 16, "indonesian", } }, + { "hi", { 17, "hindi", } }, + { "fi", { 18, "finnish", } }, + { "vi", { 19, "vietnamese", } }, + { "iw", { 20, "hebrew", } }, + { "uk", { 21, "ukrainian", } }, + { "el", { 22, "greek", } }, + { "ms", { 23, "malay", } }, + { "cs", { 24, "czech", } }, + { "ro", { 25, "romanian", } }, + { "da", { 26, "danish", } }, + { "hu", { 27, "hungarian", } }, + { "ta", { 28, "tamil", } }, + { "no", { 29, "norwegian", } }, + { "th", { 30, "thai", } }, + { "ur", { 31, "urdu", } }, + { "hr", { 32, "croatian", } }, + { "bg", { 33, "bulgarian", } }, + { "lt", { 34, "lithuanian", } }, + { "la", { 35, "latin", } }, + { "mi", { 36, "maori", } }, + { "ml", { 37, "malayalam", } }, + { "cy", { 38, "welsh", } }, + { "sk", { 39, "slovak", } }, + { "te", { 40, "telugu", } }, + { "fa", { 41, "persian", } }, + { "lv", { 42, "latvian", } }, + { "bn", { 43, "bengali", } }, + { "sr", { 44, "serbian", } }, + { "az", { 45, "azerbaijani", } }, + { "sl", { 46, "slovenian", } }, + { "kn", { 47, "kannada", } }, + { "et", { 48, "estonian", } }, + { "mk", { 49, "macedonian", } }, + { "br", { 50, "breton", } }, + { "eu", { 51, "basque", } }, + { "is", { 52, "icelandic", } }, + { "hy", { 53, "armenian", } }, + { "ne", { 54, "nepali", } }, + { "mn", { 55, "mongolian", } }, + { "bs", { 56, "bosnian", } }, + { "kk", { 57, "kazakh", } }, + { "sq", { 58, "albanian", } }, + { "sw", { 59, "swahili", } }, + { "gl", { 60, "galician", } }, + { "mr", { 61, "marathi", } }, + { "pa", { 62, "punjabi", } }, + { "si", { 63, "sinhala", } }, + { "km", { 64, "khmer", } }, + { "sn", { 65, "shona", } }, + { "yo", { 66, "yoruba", } }, + { "so", { 67, "somali", } }, + { "af", { 68, "afrikaans", } }, + { "oc", { 69, "occitan", } }, + { "ka", { 70, "georgian", } }, + { "be", { 71, "belarusian", } }, + { "tg", { 72, "tajik", } }, + { "sd", { 73, "sindhi", } }, + { "gu", { 74, "gujarati", } }, + { "am", { 75, "amharic", } }, + { "yi", { 76, "yiddish", } }, + { "lo", { 77, "lao", } }, + { "uz", { 78, "uzbek", } }, + { "fo", { 79, "faroese", } }, + { "ht", { 80, "haitian creole", } }, + { "ps", { 81, "pashto", } }, + { "tk", { 82, "turkmen", } }, + { "nn", { 83, "nynorsk", } }, + { "mt", { 84, "maltese", } }, + { "sa", { 85, "sanskrit", } }, + { "lb", { 86, "luxembourgish", } }, + { "my", { 87, "myanmar", } }, + { "bo", { 88, "tibetan", } }, + { "tl", { 89, "tagalog", } }, + { "mg", { 90, "malagasy", } }, + { "as", { 91, "assamese", } }, + { "tt", { 92, "tatar", } }, + { "haw", { 93, "hawaiian", } }, + { "ln", { 94, "lingala", } }, + { "ha", { 95, "hausa", } }, + { "ba", { 96, "bashkir", } }, + { "jw", { 97, "javanese", } }, + { "su", { 98, "sundanese", } }, +}; + +static const size_t 
MB = 1024*1024;
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
+    { MODEL_TINY,     12ull*MB },
+    { MODEL_BASE,     15ull*MB },
+    { MODEL_SMALL,    23ull*MB },
+    { MODEL_MEDIUM,   31ull*MB },
+    { MODEL_LARGE,    38ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
+    { MODEL_TINY,     18ull*MB },
+    { MODEL_BASE,     24ull*MB },
+    { MODEL_SMALL,    36ull*MB },
+    { MODEL_MEDIUM,   48ull*MB },
+    { MODEL_LARGE,    60ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
+    { MODEL_TINY,      4ull*MB },
+    { MODEL_BASE,      4ull*MB },
+    { MODEL_SMALL,     6ull*MB },
+    { MODEL_MEDIUM,    7ull*MB },
+    { MODEL_LARGE,     9ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
+    { MODEL_TINY,      4ull*MB },
+    { MODEL_BASE,      4ull*MB },
+    { MODEL_SMALL,     6ull*MB },
+    { MODEL_MEDIUM,    7ull*MB },
+    { MODEL_LARGE,     9ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_MODEL = {
+    { MODEL_TINY,     74ull*MB },
+    { MODEL_BASE,    142ull*MB },
+    { MODEL_SMALL,   466ull*MB },
+    { MODEL_MEDIUM, 1464ull*MB },
+    { MODEL_LARGE,  2952ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
+    { MODEL_TINY,      3ull*MB },
+    { MODEL_BASE,      6ull*MB },
+    { MODEL_SMALL,    16ull*MB },
+    { MODEL_MEDIUM,   43ull*MB },
+    { MODEL_LARGE,    71ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
+    { MODEL_TINY,      9ull*MB },
+    { MODEL_BASE,     18ull*MB },
+    { MODEL_SMALL,    53ull*MB },
+    { MODEL_MEDIUM,  141ull*MB },
+    { MODEL_LARGE,   235ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
+    { MODEL_TINY,      6ull*MB },
+    { MODEL_BASE,      8ull*MB },
+    { MODEL_SMALL,    13ull*MB },
+    { MODEL_MEDIUM,   22ull*MB },
+    { MODEL_LARGE,    33ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_DECODE = {
+    { MODEL_TINY,      3ull*MB },
+    { MODEL_BASE,      5ull*MB },
+    { MODEL_SMALL,    10ull*MB },
+    { MODEL_MEDIUM,   18ull*MB },
+    { MODEL_LARGE,    27ull*MB },
+};
+
+struct whisper_mel {
+    int n_len;
+    int n_mel;
+
+    std::vector<float> data;
+};
+
+struct whisper_filters {
+    int32_t n_mel;
+    int32_t n_fft;
+
+    std::vector<float> data;
+};
+
+struct whisper_vocab {
+    using id    = int32_t;
+    using token = std::string;
+
+    int n_vocab = 51864;
+
+    std::map<token, id> token_to_id;
+    std::map<id, token> id_to_token;
+
+    id token_eot  = 50256;
+    id token_sot  = 50257;
+    id token_prev = 50360;
+    id token_solm = 50361; // ??
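+    // note: the ids in this struct are for the English-only vocab
+    // (n_vocab == 51864); multilingual models shift every special token
+    // up by one, e.g. token_eot becomes 50257 and token_beg becomes 50364,
+    // via the is_multilingual() adjustment applied while the vocab is loaded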
+ id token_not = 50362; // no timestamps + id token_beg = 50363; + + // available tasks + static const id token_translate = 50358; + static const id token_transcribe = 50359; + + bool is_multilingual() const { + return n_vocab == 51865; + } +}; + +struct whisper_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector tokens; +}; + +// medium +// hparams: { +// 'n_mels': 80, +// 'n_vocab': 51864, +// 'n_audio_ctx': 1500, +// 'n_audio_state': 1024, +// 'n_audio_head': 16, +// 'n_audio_layer': 24, +// 'n_text_ctx': 448, +// 'n_text_state': 1024, +// 'n_text_head': 16, +// 'n_text_layer': 24 +// } +// +// default hparams (Whisper tiny) +struct whisper_hparams { + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; + int32_t n_audio_state = 384; + int32_t n_audio_head = 6; + int32_t n_audio_layer = 4; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t f16 = 1; +}; + +// audio encoding layer +struct whisper_layer_encoder { + // encoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // encoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // encoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // encoder.blocks.*.attn.key + struct ggml_tensor * attn_k_w; + + // encoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // encoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // encoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // encoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +// token decoding layer +struct whisper_layer_decoder { + // decoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // decoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // decoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // decoder.blocks.*.attn.key + struct ggml_tensor * attn_k_w; + + // decoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // decoder.blocks.*.cross_attn_ln + struct ggml_tensor * cross_attn_ln_0_w; + struct ggml_tensor * cross_attn_ln_0_b; + + // decoder.blocks.*.cross_attn.out + struct ggml_tensor * cross_attn_ln_1_w; + struct ggml_tensor * cross_attn_ln_1_b; + + // decoder.blocks.*.cross_attn.query + struct ggml_tensor * cross_attn_q_w; + struct ggml_tensor * cross_attn_q_b; + + // decoder.blocks.*.cross_attn.key + struct ggml_tensor * cross_attn_k_w; + + // decoder.blocks.*.cross_attn.value + struct ggml_tensor * cross_attn_v_w; + struct ggml_tensor * cross_attn_v_b; + + // decoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // decoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // decoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +struct whisper_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx; + + std::vector buf; + + int n; // number of tokens currently in the cache +}; + +struct whisper_model { + e_model type = MODEL_UNKNOWN; + + whisper_hparams hparams; + whisper_filters filters; + + // encoder.positional_embedding + struct 
ggml_tensor * e_pe;
+
+    // encoder.conv1
+    struct ggml_tensor * e_conv_1_w;
+    struct ggml_tensor * e_conv_1_b;
+
+    // encoder.conv2
+    struct ggml_tensor * e_conv_2_w;
+    struct ggml_tensor * e_conv_2_b;
+
+    // encoder.ln_post
+    struct ggml_tensor * e_ln_w;
+    struct ggml_tensor * e_ln_b;
+
+    // decoder.positional_embedding
+    struct ggml_tensor * d_pe;
+
+    // decoder.token_embedding
+    struct ggml_tensor * d_te;
+
+    // decoder.ln
+    struct ggml_tensor * d_ln_w;
+    struct ggml_tensor * d_ln_b;
+
+    std::vector<whisper_layer_encoder> layers_encoder;
+    std::vector<whisper_layer_decoder> layers_decoder;
+
+    // context
+    struct ggml_context * ctx;
+
+    // the model memory buffer is read-only and can be shared between processors
+    std::vector<uint8_t> * buf;
+
+    // tensors
+    int n_loaded;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+struct whisper_sequence {
+    std::vector<whisper_token_data> tokens;
+
+    // the accumulated transcription in the current iteration (used to truncate the tokens array)
+    int result_len;
+
+    double sum_logprobs_all; // the sum of the log probabilities of the tokens
+    double sum_logprobs;     // the sum of the log probabilities of the tokens (first result_len tokens)
+    double avg_logprobs;     // the average log probability of the tokens
+    double entropy;          // the entropy of the tokens
+    double score;            // likelihood rank score
+};
+
+// TAGS: WHISPER_DECODER_INIT
+struct whisper_decoder {
+    // each decoder keeps its own KV-cache
+    whisper_kv_cache kv_self;
+
+    // the currently generated sequence of tokens
+    whisper_sequence sequence;
+
+    int seek_delta; // the window shift found so far based on the decoded timestamp tokens
+
+    bool failed;    // has the current segment failed to decode?
+    bool completed; // has the decoder completed the current segment?
+    bool has_ts;    // have we already sampled a non-beg timestamp token for the current segment?
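+    // because kv_self is per-decoder, several decoders can expand candidate
+    // sequences in parallel (e.g. during beam search) without clobbering each
+    // other's self-attention state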
+ + // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + std::vector probs; + std::vector logits; + std::vector logprobs; + + std::vector tokens_tmp; // used for whisper_decode calls +}; + +struct whisper_context { + int64_t t_load_us = 0; + int64_t t_mel_us = 0; + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_start_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + ggml_type wtype; // weight type (FP32 or FP16) + + whisper_mel mel; + + whisper_model model; + whisper_vocab vocab; + + // cross-attention KV cache for the decoders + // shared between all decoders + whisper_kv_cache kv_cross; + + whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + + // memory buffers used by encode / decode contexts + std::vector buf_compute; + std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; + + int buf_last = 0; + size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 }; + + // decode output (2-dimensional array: [n_tokens][n_vocab]) + std::vector logits; + + std::vector result_all; + std::vector prompt_past; + + // work container used to avoid memory allocations + std::vector> logits_id; + + mutable std::mt19937 rng; // used for sampling at t > 0.0 + + int lang_id; + + // [EXPERIMENTAL] token-level timestamps data + int64_t t_beg; + int64_t t_last; + whisper_token tid_last; + std::vector energy; // PCM signal energy + + // [EXPERIMENTAL] speed-up techniques + int32_t exp_n_audio_ctx; // 0 - use default + + void use_buf(struct ggml_context * ctx, int i) { +#if defined(WHISPER_USE_SCRATCH) + size_t last_size = 0; + + if (i == -1) { + last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); + } else { + auto & buf = buf_scratch[i]; + last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); + } + + if (buf_last >= 0) { + buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); + } + + buf_last = i; +#else + (void) i; + (void) ctx; +#endif + } + + size_t get_buf_max_mem(int i) const { +#if defined(WHISPER_USE_SCRATCH) + return buf_max_size[i]; +#else + (void) i; + return 0; +#endif + } +}; + +template +static void read_safe(whisper_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool kv_cache_init( + const struct whisper_hparams & hparams, + const size_t mem_bytes, + struct whisper_kv_cache & cache, + ggml_type wtype, + int n_ctx) { + cache.buf.resize(mem_bytes); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mem = n_text_layer*n_ctx; + const int n_elements = n_text_state*n_mem; + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static bool kv_cache_reinit(struct whisper_kv_cache & cache) { + WHISPER_ASSERT(cache.ctx); + + const int n_elements = ggml_nelements(cache.k); + WHISPER_ASSERT(n_elements == ggml_nelements(cache.v)); + + const ggml_type 
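/* re-init recreates the k/v tensors inside the existing cache.buf, so the
   buffer must still fit both - verified by the asserts just below */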
wtype = cache.k->type; + WHISPER_ASSERT(wtype == cache.v->type); + + WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static void kv_cache_free(struct whisper_kv_cache & cache) { + if (cache.ctx) { + ggml_free(cache.ctx); + cache.ctx = nullptr; + } +} + +// load the model from a ggml file +// +// file format: +// +// - hparams +// - pre-computed mel filters +// - vocab +// - weights +// +// see the convert-pt-to-ggml.py script for details +// +static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { + fprintf(stderr, "%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + { + auto & hparams = model.hparams; + + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_text_ctx); + read_safe(loader, hparams.n_text_state); + read_safe(loader, hparams.n_text_head); + read_safe(loader, hparams.n_text_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.f16); + + assert(hparams.n_text_state == hparams.n_audio_state); + + if (hparams.n_audio_layer == 4) { + model.type = e_model::MODEL_TINY; + } + + if (hparams.n_audio_layer == 6) { + model.type = e_model::MODEL_BASE; + } + + if (hparams.n_audio_layer == 12) { + model.type = e_model::MODEL_SMALL; + } + + if (hparams.n_audio_layer == 24) { + model.type = e_model::MODEL_MEDIUM; + } + + if (hparams.n_audio_layer == 32) { + model.type = e_model::MODEL_LARGE; + } + + // for the big tensors, we have the option to store the data in 16-bit floats + // in order to save memory and also to speed up the computation + wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + const size_t scale = model.hparams.f16 ? 
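/* f16 models store the big tensors at half size, so the fp32 path doubles
   the f16-based memory estimates below */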
1 : 2; + + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state); + fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); + fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); + fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); + fprintf(stderr, "%s: type = %d\n", __func__, model.type); + + // print memory requirements + { + // this is the total memory required to run the inference + const size_t mem_required = + MEM_REQ_SCRATCH0.at (model.type) + + MEM_REQ_SCRATCH1.at (model.type) + + MEM_REQ_SCRATCH2.at (model.type) + + MEM_REQ_SCRATCH3.at (model.type) + + scale*MEM_REQ_MODEL.at (model.type) + + scale*MEM_REQ_KV_CROSS.at(model.type) + + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); + + // this is the memory required by one decoder + const size_t mem_required_decoder = + scale*MEM_REQ_KV_SELF.at(model.type); + + fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); + } + + // initialize all memory buffers + // always have at least one decoder + + wctx.model.buf = new std::vector(); + wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); + fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + + wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type)); + wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type)); + wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type)); + wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type)); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fft); + + filters.data.resize(filters.n_mel * filters.n_fft); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + //if (n_vocab != model.hparams.n_vocab) { + // fprintf(stderr, "%s: invalid model file '%s' (bad vocab 
size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + //} + + std::string word; + std::vector tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + // seems like we have an empty-string token in multi-language models (i = 50256) + //fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + } + + vocab.n_vocab = model.hparams.n_vocab; + if (vocab.is_multilingual()) { + vocab.token_eot++; + vocab.token_sot++; + vocab.token_prev++; + vocab.token_solm++; + vocab.token_not++; + vocab.token_beg++; + } + + if (n_vocab < model.hparams.n_vocab) { + fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + for (int i = n_vocab; i < model.hparams.n_vocab; i++) { + if (i > vocab.token_beg) { + word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; + } else if (i == vocab.token_eot) { + word = "[_EOT_]"; + } else if (i == vocab.token_sot) { + word = "[_SOT_]"; + } else if (i == vocab.token_prev) { + word = "[_PREV_]"; + } else if (i == vocab.token_not) { + word = "[_NOT_]"; + } else if (i == vocab.token_beg) { + word = "[_BEG_]"; + } else { + word = "[_extra_token_" + std::to_string(i) + "]"; + } + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + + wctx.logits_id.reserve(n_vocab); + + // TAGS: WHISPER_DECODER_INIT + wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); + + wctx.decoders[0].probs.reserve (vocab.n_vocab); + wctx.decoders[0].logits.reserve (vocab.n_vocab); + wctx.decoders[0].logprobs.reserve(vocab.n_vocab); + } + + size_t ctx_size = 0; + + const ggml_type wtype = wctx.wtype; + + { + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + // encoder + { + ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; + + ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b + + ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b + + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w; + ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b; + } + + // decoder + { + ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; + + ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; + + ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w; + ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b; + } + + // encoder layers + { + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // 
mlp_ln_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w + ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w + ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + } + + // decoder layers + { + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w + ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b + + ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b + + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b + // + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w + ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b + + ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w + ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b + } + + ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead + + fprintf(stderr, "%s: 
model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params; + params.mem_size = wctx.model.buf->size(); + params.mem_buffer = wctx.model.buf->data(); + + model.ctx = ggml_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + auto & ctx = model.ctx; + + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + model.layers_encoder.resize(n_audio_layer); + model.layers_decoder.resize(n_text_layer); + + // encoder + { + model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); + + model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.positional_embedding"] = model.e_pe; + + model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; + model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; + + model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; + model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; + + model.tensors["encoder.ln_post.weight"] = model.e_ln_w; + model.tensors["encoder.ln_post.bias"] = model.e_ln_b; + + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers_encoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["encoder.blocks." 
+ std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + } + } + + // decoder + { + model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx); + + model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); + + model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.positional_embedding"] = model.d_pe; + + model.tensors["decoder.token_embedding.weight"] = model.d_te; + + model.tensors["decoder.ln.weight"] = model.d_ln_w; + model.tensors["decoder.ln.bias"] = model.d_ln_b; + + for (int i = 0; i < n_text_layer; ++i) { + auto & layer = model.layers_decoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_1_w = 
ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; + model.tensors["decoder.blocks." 
+ std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; + } + } + } + + // load weights + { + size_t total_size = 0; + + model.n_loaded = 0; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ftype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[3] = { 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (model.tensors.find(name) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]); + return false; + } + + const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); + + if (nelements*bpe != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + fprintf(stderr, "%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + + if (model.n_loaded == 0) { + fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } + } + + wctx.rng = std::mt19937(0); + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +// evaluate the encoder +// +// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder +// part of the transformer model and returns the encoded features +// +// - model: the model +// - n_threads: number of threads to use +// - mel_offset: offset in the mel spectrogram (i.e. audio offset) +// +static bool whisper_encode( + whisper_context & wctx, + const int mel_offset, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = wctx.model; + const auto & mel_inp = wctx.mel; + const auto & hparams = model.hparams; + + const int n_ctx = wctx.exp_n_audio_ctx > 0 ? 
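/* experimental speed-up: exp_n_audio_ctx lets the caller encode fewer than
   the trained 1500 audio frames (0 selects the default) */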
wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + const int n_mels = hparams.n_mels; + assert(mel_inp.n_mel == n_mels); + + struct ggml_init_params params; + params.mem_size = wctx.buf_compute.size(); + params.mem_buffer = wctx.buf_compute.data(); + + struct ggml_context * ctx0 = ggml_init(params); + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); + assert(mel->type == GGML_TYPE_F32); + { + float * dst = (float *) mel->data; + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + + for (int j = 0; j < mel_inp.n_mel; ++j) { + for (int i = i0; i < i1; ++i) { + dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + } + } + } + + struct ggml_tensor * cur; + + // convolution + gelu + { + wctx.use_buf(ctx0, 1); + + cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_1_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + + wctx.use_buf(ctx0, 0); + + cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_2_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + } + + wctx.use_buf(ctx0, 3); + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + //static int iter = -1; + //const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + + struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + + cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); + + // =================================================================== + + // original: + //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_encoder[il]; + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpL); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); + + //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); + + // ------ + + wctx.use_buf(ctx0, 0); + +#ifdef WHISPER_USE_FLASH_ATTN + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, 
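+                        /* split the n_state channels into n_head slices of
+                           n_state/n_head each; ggml_flash_attn then operates
+                           per attention head on the permuted layout */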
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); +#else + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + ); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); + + //struct ggml_tensor * V_trans = + // ggml_permute(ctx0, + // ggml_cpy(ctx0, + // Vcur, + // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), + // 1, 2, 0, 3); + + //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 0, 2, 1, 3), + ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); +#endif + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + wctx.use_buf(ctx0, 1); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } + + // projection + { + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); + } + + wctx.use_buf(ctx0, 2); + + // add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + + wctx.use_buf(ctx0, 1); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctx0, layer.mlp_ln_b, cur)); + } + +#ifdef WHISPER_USE_FLASH_FF + wctx.use_buf(ctx0, 0); + + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + wctx.use_buf(ctx0, 0); + + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_0_b, cur), + cur); + + wctx.use_buf(ctx0, 0); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + wctx.use_buf(ctx0, 1); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + wctx.use_buf(ctx0, 0); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.mlp_1_b, cur), + cur); +#endif + } + + wctx.use_buf(ctx0, 3); + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, cur); + + wctx.use_buf(ctx0, 1); + + // cur = ln_f_g*cur + ln_f_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + 
ggml_repeat(ctx0, model.e_ln_w, cur), + cur), + ggml_repeat(ctx0, model.e_ln_b, cur)); + } + + wctx.use_buf(ctx0, -1); + + // run the computation + { + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + ggml_build_forward_expand(&gf, cur); + ggml_graph_compute (ctx0, &gf); + + //ggml_graph_print(&gf); + } + + // cur + //{ + // printf("ne0 = %d\n", cur->ne[0]); + // printf("ne1 = %d\n", cur->ne[1]); + // for (int i = 0; i < 10; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("... "); + // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { + // printf("%8.4f ", ((float *)(cur->data))[i]); + // } + // printf("\n"); + //} + + // pre-compute cross-attention memory + { + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + // TODO: hack to disconnect the encoded features from the previous graph + cur->op = GGML_OP_NONE; + cur->src0 = nullptr; + cur->src1 = nullptr; + + for (int il = 0; il < model.hparams.n_text_layer; ++il) { + auto & layer = model.layers_decoder[il]; + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.cross_attn_v_b, + Vcross), + Vcross); + + wctx.use_buf(ctx0, -1); + + //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); + } + + ggml_graph_compute(ctx0, &gf); + //ggml_graph_print(&gf); + } + + //////////////////////////////////////////////////////////////////////////// + + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wctx.get_buf_max_mem(0)/1024.0/1024.0, + // wctx.get_buf_max_mem(1)/1024.0/1024.0, + // wctx.get_buf_max_mem(2)/1024.0/1024.0, + // wctx.get_buf_max_mem(3)/1024.0/1024.0); + + ggml_free(ctx0); + + wctx.t_encode_us += ggml_time_us() - t_start_us; + wctx.n_encode++; + + return true; +} + +// evaluate the decoder +// +// given text prompt + audio features -> computes the logits for the next token +// +// - model: the model +// - n_threads: number of threads to use +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with +// +static bool whisper_decode( + whisper_context & wctx, + whisper_decoder & decoder, + const whisper_token * tokens, + const int n_tokens, + const int n_past, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + auto & kv_self = decoder.kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + + auto & logits_out = wctx.logits; + + const int n_vocab = hparams.n_vocab; + + const int n_ctx = 
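/* the decoder always runs with the full trained text context (n_text_ctx, 448) */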
hparams.n_text_ctx; + const int n_state = hparams.n_text_state; + const int n_head = hparams.n_text_head; + const int n_layer = hparams.n_text_layer; + + const int N = n_tokens; + const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + + //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); + + struct ggml_init_params params; + params.mem_size = wctx.buf_compute.size(); + params.mem_buffer = wctx.buf_compute.data(); + + struct ggml_context * ctx0 = ggml_init(params); + + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, tokens, N*ggml_element_size(embd)); + + struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + for (int i = 0; i < N; ++i) { + ((int32_t *) position->data)[i] = n_past + i; + } + + wctx.use_buf(ctx0, 3); + + // token encoding + position encoding + struct ggml_tensor * cur = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_decoder[il]; + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpL); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); + } + + // self-attention + { + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_q_b, + Qcur), + Qcur); + + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); + + // store key and value to memory + { + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // ------ + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), + n_state/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + wctx.use_buf(ctx0, 1); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + wctx.use_buf(ctx0, 0); + + //struct ggml_tensor * KQ_scaled = + // ggml_scale(ctx0, + // KQ, + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * V_trans = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + 
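/* gather the n_past + N cached value vectors for this layer, transposed so
   the following KQV matmul can consume them directly */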
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), + n_state/n_head, n_head, n_past + N), + 1, 2, 0, 3); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + } + + // projection + { + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.attn_ln_1_b, cur), + cur); + } + + wctx.use_buf(ctx0, 2); + + // add the input + struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); + + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here + + wctx.use_buf(ctx0, 1); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur), + cur), + ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur)); + } + + // cross-attention + { + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.cross_attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.cross_attn_q_b, + Qcur), + Qcur); + + Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // Kcross is already scaled + struct ggml_tensor * Kcross = + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), + n_state/n_head, n_head, M); + + struct ggml_tensor * Vcross = + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), + n_state/n_head, n_head, M); + + struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3); + + // ------ + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); + + wctx.use_buf(ctx0, 0); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + //struct ggml_tensor * KQ_scaled = + // ggml_scale(ctx0, + // KQ, + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + // no masking for cross-attention + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); + + wctx.use_buf(ctx0, 0); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + wctx.use_buf(ctx0, 1); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_state, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + } + + // projection + { + wctx.use_buf(ctx0, 0); + + cur = ggml_mul_mat(ctx0, + layer.cross_attn_ln_1_w, + cur); + + wctx.use_buf(ctx0, 1); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), + cur); + } + + wctx.use_buf(ctx0, 2); + + // add the input + cur = ggml_add(ctx0, cur, inpCA); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + wctx.use_buf(ctx0, 0); + + cur = ggml_norm(ctx0, inpFF); + + wctx.use_buf(ctx0, 1); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, layer.mlp_ln_w, cur), + cur), + ggml_repeat(ctx0, 
layer.mlp_ln_b, cur));
+            }
+
+            wctx.use_buf(ctx0, 0);
+
+            // fully connected
+            cur = ggml_mul_mat(ctx0,
+                    layer.mlp_0_w,
+                    cur);
+
+            wctx.use_buf(ctx0, 1);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_0_b, cur),
+                    cur);
+
+            wctx.use_buf(ctx0, 0);
+
+            // GELU activation
+            cur = ggml_gelu(ctx0, cur);
+
+            wctx.use_buf(ctx0, 1);
+
+            // projection
+            cur = ggml_mul_mat(ctx0,
+                    layer.mlp_1_w,
+                    cur);
+
+            wctx.use_buf(ctx0, 0);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, layer.mlp_1_b, cur),
+                    cur);
+        }
+
+        wctx.use_buf(ctx0, 3);
+
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
+
+    cur = inpL;
+
+    // norm
+    {
+        wctx.use_buf(ctx0, 0);
+
+        cur = ggml_norm(ctx0, cur);
+
+        wctx.use_buf(ctx0, 1);
+
+        cur = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.d_ln_w, cur),
+                    cur),
+                ggml_repeat(ctx0, model.d_ln_b, cur));
+    }
+
+    wctx.use_buf(ctx0, 0);
+
+    // compute logits only for the last token
+    // comment this line to compute logits for all N tokens
+    // might be useful in the future
+    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+
+    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
+
+    wctx.use_buf(ctx0, -1);
+
+    // run the computation
+    {
+        ggml_build_forward_expand(&gf, logits);
+        ggml_graph_compute       (ctx0, &gf);
+    }
+
+    // extract logits for all N tokens
+    //logits_out.resize(N*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+
+    // extract logits only for the last token
+    logits_out.resize(n_vocab);
+    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+
+    if (N > 1) {
+        //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
+        //        ggml_used_mem(ctx0)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(0)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(1)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(2)/1024.0/1024.0,
+        //        wctx.get_buf_max_mem(3)/1024.0/1024.0);
+    }
+
+    ggml_free(ctx0);
+
+    wctx.t_decode_us += ggml_time_us() - t_start_us;
+    wctx.n_decode++;
+
+    return true;
+}
+
+// convert a timestamp (given in units of 10 ms) to "hh:mm:ss.mmm" format
+//  500 -> 00:00:05.000
+// 6000 -> 00:01:00.000
+// with comma = true the millisecond separator becomes "," (SRT style)
+static std::string to_timestamp(int64_t t, bool comma = false) {
+    int64_t msec = t * 10;
+    int64_t hr = msec / (1000 * 60 * 60);
+    msec = msec - hr * (1000 * 60 * 60);
+    int64_t min = msec / (1000 * 60);
+    msec = msec - min * (1000 * 60);
+    int64_t sec = msec / 1000;
+    msec = msec - sec * 1000;
+
+    char buf[32];
+    snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
+
+    return std::string(buf);
+}
+
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+static void dft(const std::vector<float> & in, std::vector<float> & out) {
+    int N = in.size();
+
+    out.resize(N*2);
+
+    for (int k = 0; k < N; k++) {
+        float re = 0;
+        float im = 0;
+
+        for (int n = 0; n < N; n++) {
+            float angle = 2*M_PI*k*n/N;
+            re += in[n]*cos(angle);
+            im -= in[n]*sin(angle);
+        }
+
+        out[k*2 + 0] = re;
+        out[k*2 + 1] = im;
+    }
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+static void fft(const std::vector<float> & in, std::vector<float> & out) {
+    out.resize(in.size()*2);
+
+    int N = in.size();
+
+    if (N == 1) {
+        out[0] = in[0];
+        out[1] = 0;
+        return;
+    }
+
+    // odd-sized inputs cannot be split in half - fall back to the naive DFT
+    if (N%2 == 1) {
+        dft(in, out);
+        return;
+    }
+
+    std::vector<float> even;
+    std::vector<float> odd;
+
+    even.reserve(N/2);
+    odd.reserve(N/2);
+
+    for (int i = 0; i < N; i++) {
+        if (i % 2 == 0) {
+            even.push_back(in[i]);
+        } else {
+            odd.push_back(in[i]);
+        }
+    }
+
+    std::vector<float> even_fft;
+    std::vector<float> odd_fft;
+
+    fft(even, even_fft);
+    fft(odd, odd_fft);
+
+    // recombine the even/odd halves (butterfly step)
+    for (int k = 0; k < N/2; k++) {
+        float theta = 2*M_PI*k/N;
+
+        float re = cos(theta);
+        float im = -sin(theta);
+
+        float re_odd = odd_fft[2*k + 0];
+        float im_odd = odd_fft[2*k + 1];
+
+        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+        out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+        out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+    }
+}
+
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
+static bool log_mel_spectrogram(
+          whisper_context & wctx,
+              const float * samples,
+                const int   n_samples,
+                const int   /*sample_rate*/,
+                const int   fft_size,
+                const int   fft_step,
+                const int   n_mel,
+                const int   n_threads,
+    const whisper_filters & filters,
+               const bool   speed_up,
+              whisper_mel & mel) {
+    const int64_t t_start_us = ggml_time_us();
+
+    // Hanning window
+    std::vector<float> hann;
+    hann.resize(fft_size);
+    for (int i = 0; i < fft_size; i++) {
+        hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
+    }
+
+    mel.n_mel = n_mel;
+    mel.n_len = (n_samples)/fft_step;
+    mel.data.resize(mel.n_mel*mel.n_len);
+
+    const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);
+
+    //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
+    //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
+
+    std::vector<std::thread> workers(n_threads);
+    for (int iw = 0; iw < n_threads; ++iw) {
+        workers[iw] = std::thread([&](int ith) {
+            std::vector<float> fft_in;
+            fft_in.resize(fft_size);
+            for (int i = 0; i < fft_size; i++) {
+                fft_in[i] = 0.0;
+            }
+
+            std::vector<float> fft_out;
+            fft_out.resize(2*fft_size);
+
+            for (int i = ith; i < mel.n_len; i += n_threads) {
+                const int offset = i*fft_step;
+
+                // apply Hanning window
+                for (int j = 0; j < fft_size; j++) {
+                    if (offset + j < n_samples) {
+                        fft_in[j] = hann[j]*samples[offset + j];
+                    } else {
+                        fft_in[j] = 0.0;
+                    }
+                }
+
+                // FFT -> mag^2
+                fft(fft_in, fft_out);
+
+                for (int j = 0; j < fft_size; j++) {
+                    fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
+                }
+                for (int j = 1; j < fft_size/2; j++) {
+                    //if (i == 0) {
+                    //    printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
+                    //}
+                    fft_out[j] += fft_out[fft_size - j];
+                }
+                if (i == 0) {
+                    //for (int j = 0; j < fft_size; j++) {
+                    //    printf("%d: %e\n", j, fft_out[j]);
+                    //}
+                }
+
+                if (speed_up) {
+                    // scale down in the frequency domain results in a speed up in the time domain
+                    for (int j = 0; j < n_fft; j++) {
+                        fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
+                    }
+                }
+
+                // mel spectrogram
+                for (int j = 0; j < mel.n_mel; j++) {
+                    double sum = 0.0;
+
+                    for (int k = 0; k < n_fft; k++) {
+                        sum += fft_out[k]*filters.data[j*n_fft + k];
+                    }
+                    if (sum < 1e-10) {
+                        sum = 1e-10;
+                    }
+
+                    sum = log10(sum);
+
+                    mel.data[j*mel.n_len + i] = sum;
+                }
+            }
+        }, iw);
+    }
+
+    for (int iw = 0; iw < n_threads; ++iw) {
+        workers[iw].join();
+    }
+
+    // clamping and normalization
+    double mmax = -1e20;
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] > mmax) {
+            mmax = mel.data[i];
+        }
+    }
+    //printf("%s: max = %f\n", __func__, mmax);
+
+    mmax -= 8.0;
+
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] < mmax) {
+            mel.data[i] = mmax;
+        }
+
+        mel.data[i] = (mel.data[i] + 4.0)/4.0;
+    }
+
+    wctx.t_mel_us += ggml_time_us() - t_start_us;
+
+    return true;
+}
+
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<whisper_vocab::id> tokens;
+    for (const auto & word : words) {
+        if (word.empty()) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    break;
+                }
+                --j;
+            }
+            if (i == n) {
+                break;
+            }
+            if (j == i) {
+                auto sub = word.substr(i, 1);
+                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+                    tokens.push_back(vocab.token_to_id.at(sub));
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                }
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
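+
+// illustrative sketch of the greedy longest-match above, with a hypothetical
+// vocabulary: assuming vocab.token_to_id contains " wh" and "isper" but not
+// " whisper",
+//
+//     tokenize(vocab, " whisper") -> { id(" wh"), id("isper") }
+//
+// each word is scanned left to right, always taking the longest prefix present
+// in the vocabulary and falling back to single characters when nothing matches.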
+
+//
+// interface implementation
+//
+
+struct whisper_context * whisper_init_from_file(const char * path_model) {
+    whisper_model_loader loader = {};
+
+    fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
+
+    auto fin = std::ifstream(path_model, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
+        return nullptr;
+    }
+
+    loader.context = &fin;
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->read((char *)output, read_size);
+        return read_size;
+    };
+
+    loader.eof = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        return fin->eof();
+    };
+
+    loader.close = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->close();
+    };
+
+    return whisper_init(&loader);
+}
+
+struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+    struct buf_context {
+        uint8_t* buffer;
+        size_t size;
+        size_t current_offset;
+    };
+
+    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
+    whisper_model_loader loader = {};
+
+    fprintf(stderr, "%s: loading model from buffer\n", __func__);
+
+    loader.context = &ctx;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
+
+        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
+        buf->current_offset += size_to_copy;
+
+        return size_to_copy;
+    };
+
+    loader.eof = [](void * ctx) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        return buf->current_offset >= buf->size;
+    };
+
+    loader.close = [](void * /*ctx*/) { };
+
+    return whisper_init(&loader);
+}
+
+struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
+    ggml_time_init();
+
+    whisper_context * ctx = new whisper_context;
+
+    if (!whisper_model_load(loader, *ctx)) {
+        loader->close(loader->context);
+        fprintf(stderr, "%s: failed to load model\n", __func__);
+        delete ctx;
+        return nullptr;
+    }
+
+    loader->close(loader->context);
+
+    return ctx;
+}
+
+void whisper_free(struct whisper_context * ctx) {
+    if (ctx) {
+        if (ctx->model.ctx) {
+            ggml_free(ctx->model.ctx);
+        }
+        if (ctx->model.buf) {
+            delete ctx->model.buf;
+        }
+        if (ctx->kv_cross.ctx) {
+            ggml_free(ctx->kv_cross.ctx);
+        }
+        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
+            if (ctx->decoders[i].kv_self.ctx) {
+                ggml_free(ctx->decoders[i].kv_self.ctx);
+            }
+        }
+        delete ctx;
+    }
+}
+
+int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
+    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
+        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+        return -1;
+    }
+
+    return 0;
+}
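+
+// note on the speed_up variant below: it doubles both the FFT size and the hop
+// length, which halves the number of spectrogram frames, and log_mel_spectrogram()
+// then averages adjacent frequency bins (see its speed_up branch) - the resulting
+// mel spectrogram is effectively that of the audio played back at 2x speed.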
+
+// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
+int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
+    if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
+        fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
+        return -1;
+    }
+
+    return 0;
+}
+
+int whisper_set_mel(
+        struct whisper_context * ctx,
+        const float * data,
+        int n_len,
+        int n_mel) {
+    if (n_mel != WHISPER_N_MEL) {
+        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
+        return -1;
+    }
+
+    ctx->mel.n_len = n_len;
+    ctx->mel.n_mel = n_mel;
+
+    ctx->mel.data.resize(n_len*n_mel);
+    memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
+
+    return 0;
+}
+
+int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
+    if (!whisper_encode(*ctx, offset, n_threads)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return -1;
+    }
+
+    return 0;
+}
+
+int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
+    // TODO: add selected_decoder_id to context
+    const int selected_decoder_id = 0;
+
+    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
+int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
+    const auto res = tokenize(ctx->vocab, text);
+
+    if (n_max_tokens < (int) res.size()) {
+        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        return -1;
+    }
+
+    for (int i = 0; i < (int) res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+int whisper_lang_max_id() {
+    auto max_id = 0;
+    for (const auto & kv : g_lang) {
+        max_id = std::max(max_id, kv.second.first);
+    }
+
+    return max_id;
+}
+
+int whisper_lang_id(const char * lang) {
+    if (!g_lang.count(lang)) {
+        for (const auto & kv : g_lang) {
+            if (kv.second.second == lang) {
+                return kv.second.first;
+            }
+        }
+
+        fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
+        return -1;
+    }
+
+    return g_lang.at(lang).first;
+}
+
+const char * whisper_lang_str(int id) {
+    for (const auto & kv : g_lang) {
+        if (kv.second.first == id) {
+            return kv.first.c_str();
+        }
+    }
+
+    fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
+    return nullptr;
+}
+
+int whisper_lang_auto_detect(
+        struct whisper_context * ctx,
+        int offset_ms,
+        int n_threads,
+        float * lang_probs) {
+    const int seek = offset_ms/10;
+
+    if (seek < 0) {
+        fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
+        return -1;
+    }
+
+    if (seek >= ctx->mel.n_len) {
+        fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
+        return -2;
+    }
+
+    // run the encoder
+    if (whisper_encode(ctx, seek, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to encode\n", __func__);
+        return -6;
+    }
+
+    const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
+
+    if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+        fprintf(stderr, "%s: failed to decode\n", __func__);
+        return -7;
+    }
+
+    auto & logits_id = ctx->logits_id;
+    logits_id.clear();
+
+    for (const auto & kv : g_lang) {
+        const auto token_lang = whisper_token_lang(ctx, kv.second.first);
+        logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
+    }
+
+    // sort descending
+    {
+        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+        std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
+            return a.first > b.first;
+        });
+    }
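+
+    // the softmax below is the numerically stable form: subtracting the maximum
+    // logit before exponentiating avoids overflow without changing the result.
+    // illustrative example: logits { 2, 1, 0 } -> exp { 1, e^-1, e^-2 } which
+    // normalizes to ~ { 0.665, 0.245, 0.090 }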
// softmax + { + const auto max = logits_id[0].first; + + double sum = 0.0f; + for (auto & kv : logits_id) { + kv.first = exp(kv.first - max); + sum += kv.first; + } + + for (auto & kv : logits_id) { + kv.first /= sum; + } + } + + { + for (const auto & prob : logits_id) { + if (lang_probs) { + lang_probs[prob.second] = prob.first; + } + + //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); + } + } + + return logits_id[0].second; +} + +int whisper_n_len(struct whisper_context * ctx) { + return ctx->mel.n_len; +} + +int whisper_n_vocab(struct whisper_context * ctx) { + return ctx->vocab.n_vocab; +} + +int whisper_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_is_multilingual(struct whisper_context * ctx) { + return ctx->vocab.is_multilingual() ? 1 : 0; +} + +float * whisper_get_logits(struct whisper_context * ctx) { + return ctx->logits.data(); +} + +const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +whisper_token whisper_token_eot(struct whisper_context * ctx) { + return ctx->vocab.token_eot; +} + +whisper_token whisper_token_sot(struct whisper_context * ctx) { + return ctx->vocab.token_sot; +} + +whisper_token whisper_token_prev(struct whisper_context * ctx) { + return ctx->vocab.token_prev; +} + +whisper_token whisper_token_solm(struct whisper_context * ctx) { + return ctx->vocab.token_solm; +} + +whisper_token whisper_token_not(struct whisper_context * ctx) { + return ctx->vocab.token_not; +} + +whisper_token whisper_token_beg(struct whisper_context * ctx) { + return ctx->vocab.token_beg; +} + +whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { + return whisper_token_sot(ctx) + 1 + lang_id; +} + +whisper_token whisper_token_translate(void) { + return whisper_vocab::token_translate; +} + +whisper_token whisper_token_transcribe(void) { + return whisper_vocab::token_transcribe; +} + +void whisper_print_timings(struct whisper_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + const int32_t n_sample = std::max(1, ctx->n_sample); + const int32_t n_encode = std::max(1, ctx->n_encode); + const int32_t n_decode = std::max(1, ctx->n_decode); + + fprintf(stderr, "\n"); + fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); + fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample); + fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode); + fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void whisper_reset_timings(struct whisper_context * ctx) { + ctx->t_sample_us = 0; + ctx->t_encode_us = 0; + ctx->t_decode_us = 0; +} + +const char * whisper_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | 
"; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +//////////////////////////////////////////////////////////////////////////// + +struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { + struct whisper_full_params result = { + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.split_on_word =*/ false, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, + + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + + /*.suppress_blank =*/ true, + /*.suppress_non_speech_tokens =*/true, + + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, + + /*.temperature_inc =*/ 0.2f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, + + /*.greedy =*/ { + /*.best_of =*/ -1, + }, + + /*.beam_search =*/ { + /*.beam_size =*/ -1, + + /*.patience =*/ -1.0f, + }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + }; + + switch (strategy) { + case WHISPER_SAMPLING_GREEDY: + { + result.greedy = { + /*.best_of =*/ 1, + }; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + result.beam_search = { + /*.beam_size =*/ 5, + + /*.patience =*/ -1.0f, + }; + } break; + } + + return result; +} + +// forward declarations +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context & ctx, + int i_segment, + float thold_pt, + float thold_ptsum); + +// trim from start (in place) +static inline void ltrim(std::string &s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + })); +} + +// trim from end (in place) +static inline void rtrim(std::string &s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { + return !std::isspace(ch); + }).base(), s.end()); +} + +// trim from both ends (in place) +static inline void trim(std::string &s) { + rtrim(s); + ltrim(s); +} + +static inline bool should_split_on_word(const char * txt, bool split_on_word) { + if (!split_on_word) return true; + + return txt[0] == ' '; +} + +// wrap the last segment to max_len characters +// returns the number of new segments 
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) { + auto segment = ctx.result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(&ctx)) { + continue; + } + + const auto txt = whisper_token_to_str(&ctx, token.id); + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { + // split here + if (split_on_word) { + trim(text); + } + + ctx.result_all.back().text = std::move(text); + ctx.result_all.back().t1 = token.t0; + ctx.result_all.back().tokens.resize(i); + + ctx.result_all.push_back({}); + ctx.result_all.back().t0 = token.t0; + ctx.result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + ctx.result_all.back().tokens.insert( + ctx.result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + acc = 0; + text = ""; + + segment = ctx.result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + if (split_on_word) { + trim(text); + } + ctx.result_all.back().text = std::move(text); + + return res; +} + +static const std::vector non_speech_tokens +{ + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" +}; + +// process the logits for the selected decoder +// - applies logit filters +// - computes logprobs and probs +static void whisper_process_logits( + const struct whisper_context & ctx, + const struct whisper_full_params params, + struct whisper_decoder & decoder, + float temperature) { + const auto & vocab = ctx.vocab; + const auto & tokens_cur = decoder.sequence.tokens; + + const bool is_initial = tokens_cur.size() == 0; + const int n_logits = vocab.id_to_token.size(); + + WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); + + // extract the logits for the last token + // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + auto & probs = decoder.probs; + auto & logits = decoder.logits; + auto & logprobs = decoder.logprobs; + { + logits.resize(n_logits); + memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); + + if (temperature > 0.0f) { + for (int i = 0; i < n_logits; i++) { + logits[i] /= temperature; + } + } + + // will be populated a bit later + probs.resize(n_logits); + logprobs.resize(n_logits); + } + + // apply logit filters here + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 + { + // suppress blank + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 + if (params.suppress_blank) { + if (is_initial) { + logits[vocab.token_eot] = -INFINITY; + logits[vocab.token_to_id.at(" ")] = -INFINITY; + } + } + + // suppress <|notimestamps|> token + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 + logits[vocab.token_not] = -INFINITY; + + // suppress sot and solm tokens + logits[vocab.token_sot] = -INFINITY; + logits[vocab.token_solm] = -INFINITY; + + // suppress task tokens + logits[vocab.token_translate] = -INFINITY; + 
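+    // (a logit forced to -INFINITY becomes probability 0 after the softmax,
+    //  so a suppressed token can never be sampled)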
logits[vocab.token_transcribe] = -INFINITY; + + + // suppress non-speech tokens + // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + if (params.suppress_non_speech_tokens) + { + for (const std::string &token : non_speech_tokens) + { + std::string suppress_tokens[] = {token, " " + token}; + for (const std::string &suppress_token : suppress_tokens) + { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } + } + } + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } + } + + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 + { + const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + + //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } else { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + } + } + + // the initial timestamp cannot be larger than max_initial_ts + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial && params.max_initial_ts > 0.0f) { + const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_ts/precision); + + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } + + // condition timestamp tokens to be increasing + // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556 + if (decoder.has_ts) { + const int tid0 = decoder.seek_delta/2; + + for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) { + logits[i] = -INFINITY; + } + } + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } + + // if sum of probability over timestamps is above any other token, sample timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 + { + // logsumexp over timestamps + float timestamp_logprob = -INFINITY; + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); + for (int i = vocab.token_beg; i < n_logits; ++i) { + if (logprobs[i] > -INFINITY) { + logsumexp += expf(logprobs[i] - logprob_max); + } + } + if (logsumexp > 
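+            // note: the comparison just below pits the *total* probability mass of
+            // all timestamp tokens against the single best text token - e.g. 0.4
+            // spread over many timestamp ids still beats a best text token with
+            // p = 0.35, in which case only timestamp tokens remain allowed.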
0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + + //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + + if (timestamp_logprob > max_text_token_logprob) { + for (int i = 0; i < vocab.token_beg; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; + } + } + } + } + + // compute probs + { + for (int i = 0; i < n_logits; ++i) { + if (logits[i] == -INFINITY) { + probs[i] = 0.0f; + } else { + probs[i] = expf(logprobs[i]); + } + } + } + +#if 0 + // print first 100 logits - token string : logit + for (int i = 0; i < 100; i++) { + const auto token = vocab.id_to_token.at(i); + const auto prob = probs[i]; + const auto logit = logits[i]; + const auto logprob = logprobs[i]; + printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + } + + // "And", "and", " And", " and" + printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + + printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); +#endif +} + +static whisper_token_data whisper_sample_token( + whisper_context & ctx, + const whisper_decoder & decoder, + bool best) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; + + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + if (best) { + for (int i = 0; i < n_logits; ++i) { + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(ctx.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; + } + + if (result.id >= vocab.token_beg) { + result.tid = result.id; + result.pt = result.p; + } + + ctx.n_sample++; + + return result; +} + +static std::vector whisper_sample_token_topk( + whisper_context & ctx, + const whisper_decoder & decoder, + int k) { + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logits = decoder.logits; + const 
auto & logprobs = decoder.logprobs;
+
+    const int n_logits = vocab.n_vocab;
+
+    auto & logits_id = ctx.logits_id;
+
+    logits_id.clear();
+    for (int i = 0; i < n_logits; ++i) {
+        logits_id.push_back({ logits[i], i });
+    }
+
+    // keep only the k highest-scoring tokens sorted; the rest stay unordered
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + k, logits_id.end(),
+            [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
+        return a.first > b.first;
+    });
+
+    std::vector<whisper_token_data> result;
+    result.reserve(k);
+
+    whisper_token tid = vocab.token_beg;
+
+    float pt    = 0.0;
+    float ptsum = 0.0;
+
+    {
+        double sum_ts = 0.0;
+        double max_ts = 0.0;
+
+        for (int i = vocab.token_beg; i < n_logits; i++) {
+            if (probs[i] == -INFINITY) {
+                continue;
+            }
+
+            sum_ts += probs[i];
+            if (max_ts < probs[i]) {
+                max_ts = probs[i];
+                tid = i;
+            }
+        }
+
+        pt    = max_ts/(sum_ts + 1e-10);
+        ptsum = sum_ts;
+    }
+
+    for (int i = 0; i < k; ++i) {
+        const auto id = logits_id[i].second;
+
+        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
+
+        if (result[i].id >= vocab.token_beg) {
+            result[i].tid = result[i].id;
+            result[i].pt  = result[i].p;
+        }
+    }
+
+    ctx.n_sample++;
+
+    return result;
+}
+
+// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
+static void whisper_sequence_score(
+        const struct whisper_full_params & params,
+                        whisper_sequence & sequence) {
+    if (sequence.result_len == 0) {
+        return;
+    }
+
+    double result = 0.0f;
+
+    for (int i = 0; i < sequence.result_len; ++i) {
+        result += sequence.tokens[i].plog;
+    }
+
+    sequence.sum_logprobs = result;
+    sequence.avg_logprobs = result/sequence.result_len;
+
+    double penalty = sequence.result_len;
+
+    if (params.length_penalty > 0.0f) {
+        penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
+    }
+
+    sequence.score = result/penalty;
+
+    // compute the entropy of the sequence of the last 32 tokens
+    {
+        const int n = 32;
+
+        int cnt = 0;
+        double entropy = 0.0f;
+
+        std::map<whisper_token, int> token_counts;
+        for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) {
+            token_counts[sequence.tokens[i].id]++;
+            cnt++;
+        }
+
+        for (const auto & kv : token_counts) {
+            const auto p = kv.second/(double)cnt;
+            entropy -= p*log(p);
+
+            //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
+        }
+
+        sequence.entropy = entropy;
+    }
+}
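+
+// note: the entropy above is ln(32) ~ 3.47 when the last 32 tokens are all
+// distinct and 0 when they are all identical, so a low value signals a
+// repetition loop - whisper_full() below treats long sequences with entropy
+// below params.entropy_thold (default 2.4) as failed and falls back to a
+// higher temperature.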
+
+int whisper_full(
+        struct whisper_context * ctx,
+        struct whisper_full_params params,
+        const float * samples,
+        int n_samples) {
+    // clear old results
+    auto & result_all = ctx->result_all;
+
+    result_all.clear();
+
+    // compute log mel spectrogram
+    if (params.speed_up) {
+        if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -1;
+        }
+    } else {
+        if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
+            fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
+            return -2;
+        }
+    }
+
+    // auto-detect language if not specified
+    if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
+        std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
+
+        const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
+        if (lang_id < 0) {
+            fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
+            return -3;
+        }
+        ctx->lang_id = lang_id;
+        params.language = whisper_lang_str(lang_id);
+
+        fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
+    }
+
+    if (params.token_timestamps) {
+        ctx->t_beg    = 0;
+        ctx->t_last   = 0;
+        ctx->tid_last = 0;
+        ctx->energy = get_signal_energy(samples, n_samples, 32);
+    }
+
+    const int seek_start = params.offset_ms/10;
+    const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
+
+    // if length of spectrogram is less than 1s (100 samples), then return
+    // basically don't process anything that is less than 1s
+    // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
+    if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
+        return 0;
+    }
+
+    // a set of temperatures to use
+    // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
+    std::vector<float> temperatures;
+    if (params.temperature_inc > 0.0f) {
+        for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
+            temperatures.push_back(t);
+        }
+    } else {
+        temperatures.push_back(params.temperature);
+    }
+
+    // initialize the decoders
+    int n_decoders = 1;
+
+    switch (params.strategy) {
+        case WHISPER_SAMPLING_GREEDY:
+            {
+                n_decoders = params.greedy.best_of;
+            } break;
+        case WHISPER_SAMPLING_BEAM_SEARCH:
+            {
+                n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
+            } break;
+    };
+
+    n_decoders = std::max(1, n_decoders);
+
+    // TAGS: WHISPER_DECODER_INIT
+    for (int j = 1; j < n_decoders; j++) {
+        auto & decoder = ctx->decoders[j];
+
+        if (decoder.kv_self.ctx == nullptr) {
+            decoder.kv_self = ctx->decoders[0].kv_self;
+            if (!kv_cache_reinit(decoder.kv_self)) {
+                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
+                return -4;
+            }
+
+            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+
+            decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
+
+            decoder.probs.resize   (ctx->vocab.n_vocab);
+            decoder.logits.resize  (ctx->vocab.n_vocab);
+            decoder.logprobs.resize(ctx->vocab.n_vocab);
+        }
+    }
+
+    // the accumulated text context so far
+    auto & prompt_past = ctx->prompt_past;
+    if (params.no_context) {
+        prompt_past.clear();
+    }
+
+    // prepend the prompt tokens to the prompt_past
+    if (params.prompt_tokens && params.prompt_n_tokens > 0) {
+        // parse tokens from the pointer
+        for (int i = 0; i < params.prompt_n_tokens; i++) {
+            prompt_past.push_back(params.prompt_tokens[i]);
+        }
+        std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
+    }
+
+    // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
+    if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+        fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
+        return -5;
+    }
+    ctx->exp_n_audio_ctx = params.audio_ctx;
+
+    // these tokens determine the task that will be performed
+    std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
+    if (whisper_is_multilingual(ctx)) {
+        const int lang_id = whisper_lang_id(params.language);
+        ctx->lang_id = lang_id;
+        prompt_init.push_back(whisper_token_lang(ctx, lang_id));
+        if (params.translate) {
+            prompt_init.push_back(whisper_token_translate());
+        } else {
+            prompt_init.push_back(whisper_token_transcribe());
+        }
+    }
+
+    int progress_prev = 0;
+    int progress_step = 5;
+
+    int seek = seek_start;
+
+    std::vector<whisper_token> prompt;
+    prompt.reserve(whisper_n_text_ctx(ctx));
+
+    // beam-search helpers
+    struct kv_buf {
+        std::vector<uint8_t> k;
+        std::vector<uint8_t> v;
+    };
+
+    std::vector<kv_buf> kv_bufs;
+
+    struct beam_candidate {
+        int decoder_idx;
+        int seek_delta;
+
+        bool has_ts;
+
+        whisper_sequence sequence;
+    };
+
+    std::vector<beam_candidate> beam_candidates;
+
+    // main loop
+    while (true) {
+        const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
+        while (progress_cur >= progress_prev + progress_step) {
+            progress_prev += progress_step;
+            if (params.print_progress) {
+                fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
+            }
+        }
+
+        // if only 1 second is left, then stop
+        if (seek + 100 >= seek_end) {
+            break;
+        }
+
+        if (params.encoder_begin_callback) {
+            if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
+                fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
+                break;
+            }
+        }
+
+        // encode audio features starting at offset seek
+        if (!whisper_encode(*ctx, seek, params.n_threads)) {
+            fprintf(stderr, "%s: failed to encode\n", __func__);
+            return -6;
+        }
+
+        // if there is a very short audio segment left to process, we remove any past prompt since it tends
+        // to confuse the decoder and often make it repeat or hallucinate stuff
+        if (seek > seek_start && seek + 500 >= seek_end) {
+            prompt_past.clear();
+        }
+
+        int best_decoder_id = 0;
+
+        for (int it = 0; it < (int) temperatures.size(); ++it) {
+            const float t_cur = temperatures[it];
+
+            int n_decoders_cur = 1;
+
+            switch (params.strategy) {
+                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        }
+                    } break;
+                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                    {
+                        if (t_cur > 0.0f) {
+                            n_decoders_cur = params.greedy.best_of;
+                        } else {
+                            n_decoders_cur = params.beam_search.beam_size;
+                        }
+                    } break;
+            };
+
+            n_decoders_cur = std::max(1, n_decoders_cur);
+
+            WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+
+            // TAGS: WHISPER_DECODER_INIT
+            for (int j = 0; j < n_decoders_cur; ++j) {
+                auto & decoder = ctx->decoders[j];
+
+                decoder.kv_self.n = 0;
+
+                decoder.sequence.tokens.clear();
+                decoder.sequence.result_len       = 0;
+                decoder.sequence.sum_logprobs_all = 0.0;
+                decoder.sequence.sum_logprobs     = -INFINITY;
+                decoder.sequence.avg_logprobs     = -INFINITY;
+                decoder.sequence.entropy          = 0.0;
+                decoder.sequence.score            = -INFINITY;
+
+                decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
+
+                decoder.failed    = false;
+                decoder.completed = false;
+                decoder.has_ts    = false;
+            }
+
+            // init prompt and kv cache for the current iteration
+            // run whisper_decode() only for decoder 0 and copy the results for the other decoders
+            {
+                prompt.clear();
+
+                // if we have already generated some text, use it as a prompt to condition the next generation
+                if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
+                    int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
+
+                    prompt = { whisper_token_prev(ctx) };
+                    prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
+                }
+
+                // init new transcription with sot, language (opt) and task tokens
+                prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
+                // the prompt now has the layout: [PREV] past_tokens ... [SOT] ([LANG] [TASK])
+
+                // print the prompt
+                WHISPER_PRINT_DEBUG("\n\n");
+                for (int i = 0; i < (int) prompt.size(); i++) {
+                    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+                }
+                WHISPER_PRINT_DEBUG("\n\n");
+
+                if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+ fprintf(stderr, "%s: failed to decode\n", __func__); + return -7; + } + + { + const int64_t t_start_sample_us = ggml_time_us(); + + whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); + + ctx->decoders[0].kv_self.n += prompt.size(); + + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + + decoder.kv_self.n += prompt.size(); + + memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + } + + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + const int64_t t_start_sample_us = ggml_time_us(); + + // store the KV caches of all decoders when doing beam-search + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + kv_bufs.resize(n_decoders_cur); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k)); + kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v)); + + memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size()); + memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size()); + } + + beam_candidates.clear(); + } + + // generate new sequence candidates for each decoder + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); + beam_candidates.back().sequence.tokens.push_back(token); + beam_candidates.back().sequence.sum_logprobs_all += token.plog; + + //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all); + } + } break; + }; + } + + // for beam-search, choose the top candidates and update the KV caches + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + std::sort( + beam_candidates.begin(), + beam_candidates.end(), + [](const beam_candidate & a, const beam_candidate & b) { + return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; + }); + + int cur_c = 0; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & cur = beam_candidates[cur_c++]; + + while 
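+                    // presumably: candidates with an identical cumulative score are
+                    // duplicates of the same sequence, so they are skipped here to
+                    // avoid assigning the same hypothesis to multiple decoders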
(beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { + ++cur_c; + } + + decoder.sequence = cur.sequence; + decoder.seek_delta = cur.seek_delta; + decoder.has_ts = cur.has_ts; + + memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size()); + memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size()); + + WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", + __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); + } + } + + // update the decoder state + // - check if the sequence is completed + // - check if the sequence is failed + // - update sliding window based on timestamp tokens + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & has_ts = decoder.has_ts; + auto & failed = decoder.failed; + auto & completed = decoder.completed; + auto & seek_delta = decoder.seek_delta; + auto & result_len = decoder.sequence.result_len; + + { + const auto & token = decoder.sequence.tokens.back(); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + failed = true; // TODO: maybe this is not a failure ? + continue; + } + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + +#ifdef WHISPER_DEBUG + { + const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; + WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + } +#endif + + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + failed = true; + continue; + } + } + + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + + completed = true; + continue; + } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + completed = true; + continue; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; + continue; + } + } + + // check if all decoders have finished (i.e. 
completed or failed) + { + bool completed_all = true; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + completed_all = false; + } + + if (completed_all) { + break; + } + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + + // obtain logits for the next token + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed || decoder.completed) { + continue; + } + + decoder.tokens_tmp.resize(1); + decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); + + if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -8; + } + + { + const int64_t t_start_sample_us = ggml_time_us(); + + whisper_process_logits(*ctx, params, decoder, t_cur); + + ++decoder.kv_self.n; + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + } + } + + // rank the resulting sequences and select the best one + { + double best_score = -INFINITY; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed) { + continue; + } + + decoder.sequence.tokens.resize(decoder.sequence.result_len); + whisper_sequence_score(params, decoder.sequence); + + WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); + + if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { + WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + __func__, j, decoder.sequence.entropy, params.entropy_thold); + + decoder.failed = true; + ctx->n_fail_h++; + + continue; + } + + if (best_score < decoder.sequence.score) { + best_score = decoder.sequence.score; + best_decoder_id = j; + } + } + + WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); + } + + // was the decoding successful for the current temperature? 
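+        // illustrative: with the default parameters the temperature ladder is
+        // 0.0, 0.2, 0.4, 0.6, 0.8, 1.0 - the first temperature whose best decoder
+        // neither failed nor scored avg_logprobs below logprob_thold (default -1.0)
+        // is accepted; otherwise the chunk is re-decoded at the next temperature.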
+ { + bool success = true; + + const auto & decoder = ctx->decoders[best_decoder_id]; + + if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { + success = false; + ctx->n_fail_p++; + } + + if (success) { + //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + //} + + break; + } + } + + WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); + } + + // output results through a user-provided callback + { + const auto & best_decoder = ctx->decoders[best_decoder_id]; + + const auto seek_delta = best_decoder.seek_delta; + const auto result_len = best_decoder.sequence.result_len; + + const auto & tokens_cur = best_decoder.sequence.tokens; + + //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + + // update prompt_past + prompt_past.clear(); + if (prompt.front() == whisper_token_prev(ctx)) { + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + } + + for (int i = 0; i < result_len; ++i) { + prompt_past.push_back(tokens_cur[i].id); + } + + // store the text from this iteration + if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + + std::string text; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + } else { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); + + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + } + } + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; + } + i--; + t0 = t1; + i0 = i + 1; + } + } + + if (!text.empty()) { + const auto t1 = seek + seek_delta; + + const auto tt0 = params.speed_up ? 
2*t0 : t0;
+                    const auto tt1 = params.speed_up ? 2*t1 : t1;
+
+                    if (params.print_realtime) {
+                        if (params.print_timestamps) {
+                            printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
+                        } else {
+                            printf("%s", text.c_str());
+                            fflush(stdout);
+                        }
+                    }
+
+                    result_all.push_back({ tt0, tt1, text, {} });
+                    for (int j = i0; j < (int) tokens_cur.size(); j++) {
+                        result_all.back().tokens.push_back(tokens_cur[j]);
+                    }
+
+                    int n_new = 1;
+
+                    if (params.token_timestamps) {
+                        whisper_exp_compute_token_level_timestamps(
+                                *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                        if (params.max_len > 0) {
+                            n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
+                        }
+                    }
+                    if (params.new_segment_callback) {
+                        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
+                    }
+                }
+            }
+
+            // update audio window
+            seek += seek_delta;
+
+            WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
+        }
+    }
+
+    return 0;
+}
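The next function, whisper_full_parallel, fans the transcription out across several threads. A minimal, hedged usage sketch (not part of the patch): the model path, the function name transcribe_parallel, and the pcmf32 buffer are illustrative, and 16 kHz mono float PCM is assumed as everywhere else in this file.

// hedged sketch: transcribe with 4 processors
#include <vector>

int transcribe_parallel(const std::vector<float> & pcmf32) {
    struct whisper_context * wctx = whisper_init_from_file("models/ggml-base.en.bin");
    if (wctx == nullptr) {
        return 1; // model failed to load
    }

    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // n_processors == 1 simply falls through to whisper_full()
    const int ret = whisper_full_parallel(wctx, wparams, pcmf32.data(), (int) pcmf32.size(), 4);

    whisper_free(wctx);
    return ret;
}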
+
+int whisper_full_parallel(
+        struct whisper_context * ctx,
+        struct whisper_full_params params,
+        const float * samples,
+        int n_samples,
+        int n_processors) {
+    if (n_processors == 1) {
+        return whisper_full(ctx, params, samples, n_samples);
+    }
+
+    int ret = 0;
+
+    // prepare separate contexts for each thread
+    std::vector<whisper_context> ctxs(n_processors - 1);
+
+    for (int i = 0; i < n_processors - 1; ++i) {
+        auto & ctx_p = ctxs[i];
+
+        ctx_p = *ctx;
+
+        ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
+
+        ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
+
+        if (!kv_cache_reinit(ctx_p.kv_cross)) {
+            fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
+            return false;
+        }
+
+        // TAGS: WHISPER_DECODER_INIT
+        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+            if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
+                fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
+                return false;
+            }
+
+            ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
+
+            ctx_p.decoders[j].probs.reserve   (ctx_p.vocab.n_vocab);
+            ctx_p.decoders[j].logits.reserve  (ctx_p.vocab.n_vocab);
+            ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
+        }
+    }
+
+    const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
+    const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
+
+    // the calling thread will process the first chunk
+    // while the other threads will process the remaining chunks
+
+    std::vector<std::thread> workers(n_processors - 1);
+    for (int i = 0; i < n_processors - 1; ++i) {
+        const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
+        const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
+
+        auto params_cur = params;
+
+        params_cur.offset_ms = 0;
+        params_cur.print_progress = false;
+        params_cur.print_realtime = false;
+
+        params_cur.new_segment_callback = nullptr;
+        params_cur.new_segment_callback_user_data = nullptr;
+
+        workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
+    }
+
+    {
+        auto params_cur = params;
+
+        ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
+    }
+
+    for (int i = 0; i < n_processors - 1; ++i) {
+        workers[i].join();
+    }
+
+    const int64_t offset_t = (int64_t) params.offset_ms/10.0;
+
+    // combine results into ctx->result_all
+    for (int i = 0; i < n_processors - 1; ++i) {
+        auto & results_i = ctxs[i].result_all;
+
+        for (auto & result : results_i) {
+            // correct the segment timestamp taking into account the offset
+            result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+            result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
+
+            // make sure that segments are not overlapping
+            if (!ctx->result_all.empty()) {
+                result.t0 = std::max(result.t0, ctx->result_all.back().t1);
+            }
+
+            ctx->result_all.push_back(std::move(result));
+
+            // call the new_segment_callback for each segment
+            if (params.new_segment_callback) {
+                params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
+            }
+        }
+
+        ctx->t_mel_us    += ctxs[i].t_mel_us;
+        ctx->t_sample_us += ctxs[i].t_sample_us;
+        ctx->t_encode_us += ctxs[i].t_encode_us;
+        ctx->t_decode_us += ctxs[i].t_decode_us;
+
+        kv_cache_free(ctx->kv_cross);
+
+        for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
+            kv_cache_free(ctx->decoders[j].kv_self);
+        }
+    }
+
+    // average the timings
+    ctx->t_mel_us    /= n_processors;
+    ctx->t_sample_us /= n_processors;
+    ctx->t_encode_us /= n_processors;
+    ctx->t_decode_us /= n_processors;
+
+    // print information about the audio boundaries
+    fprintf(stderr, "\n");
+    fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    for (int i = 0; i < n_processors - 1; ++i) {
+        fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+    }
+    fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
+
+    return ret;
+}
+
+int whisper_full_n_segments(struct whisper_context * ctx) {
+    return ctx->result_all.size();
+}
+
+int whisper_full_lang_id(struct whisper_context * ctx) {
+    return ctx->lang_id;
+}
+
+int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
+    return ctx->result_all[i_segment].t0;
+}
+
+int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
+    return ctx->result_all[i_segment].t1;
+}
+
+const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
+    return ctx->result_all[i_segment].text.c_str();
+}
+
+int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
+    return ctx->result_all[i_segment].tokens.size();
+}
+
+const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
+}
+
+whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token].id;
+}
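The accessors above are the intended way to inspect results after a whisper_full() call. A hedged sketch of iterating them (the helper name dump_tokens is illustrative; whisper_full_get_token_p is defined just below):

// hedged sketch: print every token of every segment held by "ctx"
#include <cstdio>

void dump_tokens(struct whisper_context * ctx) {
    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = 0; i < n_segments; ++i) {
        const int n_tokens = whisper_full_n_tokens(ctx, i);
        for (int j = 0; j < n_tokens; ++j) {
            printf("segment %d, token %d: '%s' (p = %.3f)\n", i, j,
                   whisper_full_get_token_text(ctx, i, j),
                   whisper_full_get_token_p(ctx, i, j));
        }
    }
}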
+
+struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token];
+}
+
+float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
+    return ctx->result_all[i_segment].tokens[i_token].p;
+}
+
+// =================================================================================================
+
+//
+// Temporary interface needed for exposing ggml interface
+// Will be removed in the future when ggml becomes a separate library
+//
+
+WHISPER_API int whisper_bench_memcpy(int n_threads) {
+    ggml_time_init();
+
+    size_t n   = 50;
+    size_t arr = n_threads > 0 ? 1024 : n_threads; // trick to avoid compiler optimizations
+
+    // 1 GB array
+    const size_t size = arr*1024llu*1024llu;
+
+    char * src = (char *) malloc(size);
+    char * dst = (char *) malloc(size);
+
+    for (size_t i = 0; i < size; i++) src[i] = i;
+
+    memcpy(dst, src, size); // heat-up
+
+    double tsum = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        const int64_t t0 = ggml_time_us();
+
+        memcpy(dst, src, size);
+
+        const int64_t t1 = ggml_time_us();
+
+        tsum += (t1 - t0)*1e-6;
+
+        src[0] = rand();
+    }
+
+    fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+
+    // needed to prevent the compiler from optimizing the memcpy away
+    {
+        double sum = 0.0;
+
+        for (size_t i = 0; i < size; i++) sum += dst[i];
+
+        fprintf(stderr, "sum:    %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
+    }
+
+    free(src);
+    free(dst);
+
+    return 0;
+}
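The matrix benchmark that follows reports throughput as s = (2.0*N*N*N*n)/tsum * 1e-9, i.e. it assumes roughly 2N^3 floating-point operations per N x N multiply. A standalone sketch of that arithmetic, for intuition only:

// hedged sketch: FLOP counts assumed by the ggml_mul_mat benchmark below
#include <cstdio>

int main() {
    const long long sizes[] = { 64, 128, 256, 512, 1024, 2048, 4096 };
    for (long long N : sizes) {
        const double gflop = 2.0*N*N*N*1e-9; // ~2*N^3 ops per N x N multiply
        printf("N = %4lld -> %8.3f GFLOP per multiply\n", N, gflop);
    }
    return 0; // e.g. N = 4096 works out to ~137.439 GFLOP
}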
+
+WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
+    ggml_time_init();
+
+    const int n_max = 128;
+
+    const std::vector<size_t> sizes = {
+        64, 128, 256, 512, 1024, 2048, 4096,
+    };
+
+    const size_t N_max = sizes.back();
+
+    // a: N*N*sizeof(float)
+    // b: N*N*sizeof(float)
+    // c: N*N*sizeof(float)
+    // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
+    std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
+
+    for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
+
+    for (int j = 0; j < (int) sizes.size(); j++) {
+        int n_fp16 = 0;
+        int n_fp32 = 0;
+
+        // GFLOPS/s
+        double s_fp16 = 0.0;
+        double s_fp32 = 0.0;
+
+        const size_t N = sizes[j];
+
+        for (int k = 0; k < 2; ++k) {
+            const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+            double & s = k == 0 ? s_fp16 : s_fp32;
+            int    & n = k == 0 ? n_fp16 : n_fp32;
+
+            struct ggml_init_params gparams = {
+                /*.mem_size   =*/ buf.size(),
+                /*.mem_buffer =*/ buf.data(),
+            };
+
+            struct ggml_context * ctx0 = ggml_init(gparams);
+
+            struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype,         N, N);
+            struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);
+
+            struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);
+
+            struct ggml_cgraph gf = ggml_build_forward(c);
+
+            gf.n_threads = n_threads;
+
+            double tsum = 0.0;
+
+            // heat-up
+            ggml_graph_compute(ctx0, &gf);
+
+            for (int i = 0; i < n_max; ++i) {
+                const int64_t t0 = ggml_time_us();
+
+                ggml_graph_compute(ctx0, &gf);
+
+                const int64_t t1 = ggml_time_us();
+
+                tsum += (t1 - t0)*1e-6;
+                n++;
+
+                if (tsum > 1.0 && n >= 3) {
+                    break;
+                }
+            }
+
+            ggml_free(ctx0);
+
+            s = ((2.0*N*N*N*n)/tsum)*1e-9;
+        }
+
+        fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
+            N, N, s_fp16, n_fp16, s_fp32, n_fp32);
+    }
+
+    return 0;
+}
+
+// =================================================================================================
+
+// =================================================================================================
+
+//
+// Experimental stuff below
+//
+// Not sure if these should be part of the library at all, because the quality of the results is not
+// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
+//
+
+// =================================================================================================
+
+//
+// token-level timestamps
+//
+
+static int timestamp_to_sample(int64_t t, int n_samples) {
+    return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
+static int64_t sample_to_timestamp(int i_sample) {
+    return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
+}
+
+// a cost-function / heuristic that is high for text that takes longer to pronounce
+// obviously, can be improved
+static float voice_length(const std::string & text) {
+    float res = 0.0f;
+
+    for (char c : text) {
+        if (c == ' ') {
+            res += 0.01f;
+        } else if (c == ',') {
+            res += 2.00f;
+        } else if (c == '.') {
+            res += 3.00f;
+        } else if (c == '!') {
+            res += 3.00f;
+        } else if (c == '?') {
+            res += 3.00f;
+        } else if (c >= '0' && c <= '9') {
+            res += 3.00f;
+        } else {
+            res += 1.00f;
+        }
+    }
+
+    return res;
+}
+
+// average the fabs of the signal
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
+    const int hw = n_samples_per_half_window;
+
+    std::vector<float> result(n_samples);
+
+    for (int i = 0; i < n_samples; i++) {
+        float sum = 0;
+        for (int j = -hw; j <= hw; j++) {
+            if (i + j >= 0 && i + j < n_samples) {
+                sum += fabs(signal[i + j]);
+            }
+        }
+        result[i] = sum/(2*hw + 1);
+    }
+
+    return result;
+}
+
+static void whisper_exp_compute_token_level_timestamps(
+        struct whisper_context & ctx,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum) {
+    auto & segment = ctx.result_all[i_segment];
+    auto & tokens  = segment.tokens;
+
+    const int n_samples = ctx.energy.size();
+
+    if (n_samples == 0) {
+        fprintf(stderr, "%s: no signal data available\n", __func__);
+        return;
+    }
+
+    const int64_t t0 = segment.t0;
+    const int64_t t1 = segment.t1;
+
+    const int n = tokens.size();
+
+    if (n == 0) {
+        return;
+    }
+
+    if (n == 1) {
+        tokens[0].t0 = t0;
+        tokens[0].t1 = t1;
+
+        return;
+    }
+
+    auto & t_beg    = ctx.t_beg;
+    auto & t_last   = ctx.t_last;
+    auto & tid_last = ctx.tid_last;
+
+    for (int j = 0; j < n; ++j) {
+        auto & token = 
tokens[j]; + + if (j == 0) { + if (token.id == whisper_token_beg(&ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(&ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while (true) { + while (p1 < n && tokens[p1].t1 < 0) { + p1++; + } + + if (p1 >= n) { + p1--; + } + + //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(&ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += ctx.energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (ctx.energy[k] > thold && j > 0) { + while (k > 0 && ctx.energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (ctx.energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (ctx.energy[k] > thold) { + while (k < n_samples - 1 && ctx.energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (ctx.energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + 
//        if (j < n - 1) {
+    //            tokens[j].t1 = tokens[j].t1 + t_expand;
+    //        }
+    //    }
+    //}
+
+    // debug info
+    //for (int j = 0; j < n; ++j) {
+    //    const auto & token = tokens[j];
+    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
+    //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
+
+    //    if (tokens[j].id >= whisper_token_eot(&ctx)) {
+    //        continue;
+    //    }
+    //}
+}
diff --git a/bindings/ruby/ext/whisper.h b/bindings/ruby/ext/whisper.h
new file mode 100644
index 0000000..7eece79
--- /dev/null
+++ b/bindings/ruby/ext/whisper.h
@@ -0,0 +1,379 @@
+#ifndef WHISPER_H
+#define WHISPER_H
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdbool.h>
+
+#ifdef WHISPER_SHARED
+#    ifdef _WIN32
+#        ifdef WHISPER_BUILD
+#            define WHISPER_API __declspec(dllexport)
+#        else
+#            define WHISPER_API __declspec(dllimport)
+#        endif
+#    else
+#        define WHISPER_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define WHISPER_API
+#endif
+
+#define WHISPER_SAMPLE_RATE 16000
+#define WHISPER_N_FFT       400
+#define WHISPER_N_MEL       80
+#define WHISPER_HOP_LENGTH  160
+#define WHISPER_CHUNK_SIZE  30
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+    //
+    // C interface
+    //
+    // The following interface is thread-safe as long as the same whisper_context is not used by multiple threads
+    // concurrently.
+    //
+    // Basic usage:
+    //
+    //     #include "whisper.h"
+    //
+    //     ...
+    //
+    //     struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
+    //
+    //     struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+    //
+    //     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
+    //         fprintf(stderr, "failed to process audio\n");
+    //         return 7;
+    //     }
+    //
+    //     const int n_segments = whisper_full_n_segments(ctx);
+    //     for (int i = 0; i < n_segments; ++i) {
+    //         const char * text = whisper_full_get_segment_text(ctx, i);
+    //         printf("%s", text);
+    //     }
+    //
+    //     whisper_free(ctx);
+    //
+    //     ...
+    //
+    //     This is a demonstration of the most straightforward usage of the library.
+    //     "pcmf32" contains the RAW audio data in 32-bit floating point format.
+    //
+    //     The interface also allows for more fine-grained control over the computation, but it requires a deeper
+    //     understanding of how the model works.
+    //
+
+    struct whisper_context;
+
+    typedef int whisper_token;
+
+    typedef struct whisper_token_data {
+        whisper_token id;  // token id
+        whisper_token tid; // forced timestamp token id
+
+        float p;           // probability of the token
+        float plog;        // log probability of the token
+        float pt;          // probability of the timestamp token
+        float ptsum;       // sum of probabilities of all timestamp tokens
+
+        // token-level timestamp data
+        // do not use if you haven't computed token-level timestamps
+        int64_t t0;        // start time of the token
+        int64_t t1;        // end time of the token
+
+        float vlen;        // voice length of the token
+    } whisper_token_data;
+
+    typedef struct whisper_model_loader {
+        void * context;
+
+        size_t (*read)(void * ctx, void * output, size_t read_size);
+        bool    (*eof)(void * ctx);
+        void  (*close)(void * ctx);
+    } whisper_model_loader;
+
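whisper_init() below drives this loader interface. A hedged sketch of a loader backed by stdio (the helper names file_read, file_eof, and file_close are illustrative):

// hedged sketch: a whisper_model_loader over a FILE *
#include <cstdio>

static size_t file_read(void * ctx, void * output, size_t read_size) {
    return fread(output, 1, read_size, (FILE *) ctx);
}

static bool file_eof(void * ctx) {
    return feof((FILE *) ctx) != 0;
}

static void file_close(void * ctx) {
    fclose((FILE *) ctx);
}

// usage (assumes the model file exists):
//   FILE * f = fopen("models/ggml-base.en.bin", "rb");
//   whisper_model_loader loader = { f, file_read, file_eof, file_close };
//   struct whisper_context * wctx = whisper_init(&loader);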
+    // Various functions for loading a ggml whisper model.
+    // Allocate (almost) all memory needed for the model.
+    // Return NULL on failure
+    WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
+    WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
+
+    // Frees all memory allocated by the model.
+    WHISPER_API void whisper_free(struct whisper_context * ctx);
+
+    // Convert RAW PCM audio to log mel spectrogram.
+    // The resulting spectrogram is stored inside the provided whisper context.
+    // Returns 0 on success
+    WHISPER_API int whisper_pcm_to_mel(
+            struct whisper_context * ctx,
+            const float * samples,
+            int   n_samples,
+            int   n_threads);
+
+    // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
+    // The resulting spectrogram is stored inside the provided whisper context.
+    // Returns 0 on success
+    WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
+        struct whisper_context* ctx,
+        const float* samples,
+        int   n_samples,
+        int   n_threads);
+
+
+    // This can be used to set a custom log mel spectrogram inside the provided whisper context.
+    // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
+    // n_mel must be 80
+    // Returns 0 on success
+    WHISPER_API int whisper_set_mel(
+            struct whisper_context * ctx,
+            const float * data,
+            int   n_len,
+            int   n_mel);
+
+    // Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
+    // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
+    // offset can be used to specify the offset of the first frame in the spectrogram.
+    // Returns 0 on success
+    WHISPER_API int whisper_encode(
+            struct whisper_context * ctx,
+            int   offset,
+            int   n_threads);
+
+    // Run the Whisper decoder to obtain the logits and probabilities for the next token.
+    // Make sure to call whisper_encode() first.
+    // tokens + n_tokens is the provided context for the decoder.
+    // n_past is the number of tokens to use from previous decoder calls.
+    // Returns 0 on success
+    // TODO: add support for multiple decoders
+    WHISPER_API int whisper_decode(
+            struct whisper_context * ctx,
+            const whisper_token * tokens,
+            int   n_tokens,
+            int   n_past,
+            int   n_threads);
+
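The three calls above form the low-level pipeline: mel spectrogram, encoder, decoder. A hedged sketch of a single decode step (the function name decode_one_step, the thread count, and the pcmf32 buffer are illustrative):

// hedged sketch: mel -> encoder -> one decoder step; error handling trimmed
#include <vector>

int decode_one_step(struct whisper_context * ctx, const std::vector<float> & pcmf32) {
    if (whisper_pcm_to_mel(ctx, pcmf32.data(), (int) pcmf32.size(), 4) != 0) return -1;
    if (whisper_encode(ctx, 0, 4) != 0) return -1;

    // start decoding from the start-of-transcript token
    const whisper_token tokens[1] = { whisper_token_sot(ctx) };
    if (whisper_decode(ctx, tokens, 1, 0, 4) != 0) return -1;

    // logits for the next token: one row per input token, n_vocab columns
    float * logits = whisper_get_logits(ctx);
    return logits != nullptr ? 0 : -1;
}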
+    // Convert the provided text into tokens.
+    // The tokens pointer must be large enough to hold the resulting tokens.
+    // Returns the number of tokens on success, no more than n_max_tokens
+    // Returns -1 on failure
+    // TODO: not sure if correct
+    WHISPER_API int whisper_tokenize(
+            struct whisper_context * ctx,
+            const char * text,
+            whisper_token * tokens,
+            int   n_max_tokens);
+
+    // Largest language id (i.e. number of available languages - 1)
+    WHISPER_API int whisper_lang_max_id();
+
+    // Return the id of the specified language, returns -1 if not found
+    // Examples:
+    //   "de" -> 2
+    //   "german" -> 2
+    WHISPER_API int whisper_lang_id(const char * lang);
+
+    // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
+    WHISPER_API const char * whisper_lang_str(int id);
+
+    // Use mel data at offset_ms to try and auto-detect the spoken language
+    // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
+    // Returns the top language id or negative on failure
+    // If not null, fills the lang_probs array with the probabilities of all languages
+    // The array must be whisper_lang_max_id() + 1 in size
+    // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
+    WHISPER_API int whisper_lang_auto_detect(
+            struct whisper_context * ctx,
+            int   offset_ms,
+            int   n_threads,
+            float * lang_probs);
+
+    WHISPER_API int whisper_n_len          (struct whisper_context * ctx); // mel length
+    WHISPER_API int whisper_n_vocab        (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_text_ctx     (struct whisper_context * ctx);
+    WHISPER_API int whisper_n_audio_ctx    (struct whisper_context * ctx);
+    WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
+
+    // Token logits obtained from the last call to whisper_decode()
+    // The logits for the last token are stored in the last row
+    // Rows: n_tokens
+    // Cols: n_vocab
+    WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
+
+    // Token Id -> String. Uses the vocabulary in the provided context
+    WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
+
+    // Special tokens
+    WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
+
+    // Task tokens
+    WHISPER_API whisper_token whisper_token_translate (void);
+    WHISPER_API whisper_token whisper_token_transcribe(void);
+
+    // Performance information
+    WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
+    WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);
+
+    // Print system information
+    WHISPER_API const char * whisper_print_system_info(void);
+
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Available sampling strategies
+    enum whisper_sampling_strategy {
+        WHISPER_SAMPLING_GREEDY,      // similar to OpenAI's GreedyDecoder
+        WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
+    };
+
+    // Text segment callback
+    // Called on every newly generated text segment
+    // Use the whisper_full_...() functions to obtain the text segments
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
+
+    // Encoder begin callback
+    // If not NULL, called before the encoder starts
+    // If it returns false, the computation is aborted
+    typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
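whisper_full_params below wires these callbacks in. A hedged sketch of a new-segment callback (the function name on_new_segment is illustrative):

// hedged sketch: print each freshly generated segment as it arrives
#include <cstdio>

static void on_new_segment(struct whisper_context * ctx, int n_new, void * user_data) {
    (void) user_data;
    const int n_segments = whisper_full_n_segments(ctx);
    // the last n_new segments are the ones generated by this call
    for (int i = n_segments - n_new; i < n_segments; ++i) {
        printf("[%lld -> %lld] %s\n",
               (long long) whisper_full_get_segment_t0(ctx, i),
               (long long) whisper_full_get_segment_t1(ctx, i),
               whisper_full_get_segment_text(ctx, i));
    }
}

// wiring (assumes wparams from whisper_full_default_params()):
//   wparams.new_segment_callback = on_new_segment;
//   wparams.new_segment_callback_user_data = nullptr;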
+
+    // Parameters for the whisper_full() function
+    // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
+    // whisper_full_default_params()
+    struct whisper_full_params {
+        enum whisper_sampling_strategy strategy;
+
+        int n_threads;
+        int n_max_text_ctx;     // max tokens to use from past text as prompt for the decoder
+        int offset_ms;          // start offset in ms
+        int duration_ms;        // audio duration to process in ms
+
+        bool translate;
+        bool no_context;        // do not use past transcription (if any) as initial prompt for the decoder
+        bool single_segment;    // force single segment output (useful for streaming)
+        bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;    // print progress information
+        bool print_realtime;    // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps;  // print timestamps for each text segment when printing realtime
+
+        // [EXPERIMENTAL] token-level timestamps
+        bool  token_timestamps; // enable token-level timestamps
+        float thold_pt;         // timestamp token probability threshold (~0.01)
+        float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
+        int   max_len;          // max segment length in characters
+        bool  split_on_word;    // split on word rather than on token (when used with max_len)
+        int   max_tokens;       // max tokens per segment (0 = no limit)
+
+        // [EXPERIMENTAL] speed-up techniques
+        // note: these can significantly reduce the quality of the output
+        bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
+        int  audio_ctx;         // overwrite the audio context size (0 = use default)
+
+        // tokens to provide to the whisper decoder as initial prompt
+        // these are prepended to any existing text context from a previous call
+        const whisper_token * prompt_tokens;
+        int prompt_n_tokens;
+
+        // for auto-detection, set to nullptr, "" or "auto"
+        const char * language;
+
+        // common decoding parameters:
+        bool suppress_blank;             // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+
+        float temperature;      // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
+        float max_initial_ts;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
+        float length_penalty;   // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
+
+        // fallback parameters
+        // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
+        float temperature_inc;
+        float entropy_thold;    // similar to OpenAI's "compression_ratio_threshold"
+        float logprob_thold;
+        float no_speech_thold;  // TODO: not implemented
+
+        struct {
+            int best_of;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
+        } greedy;
+
+        struct {
+            int beam_size;  // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
+
+            float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
+        } beam_search;
+
+        // called for every newly generated text segment
+        whisper_new_segment_callback new_segment_callback;
+        void * new_segment_callback_user_data;
+
+        // called each time before the encoder starts
+        whisper_encoder_begin_callback encoder_begin_callback;
+        void * encoder_begin_callback_user_data;
+    };
+
+    WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
+
+    // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+    // Uses the specified decoding strategy to obtain the text.
+ WHISPER_API int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full() + // It seems this approach can offer some speedup in some cases. + // However, the transcription accuracy can be worse at the beginning and end of each chunk. + WHISPER_API int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors); + + // Number of generated text segments. + // A segment can be a few words, a sentence, or even a paragraph. + WHISPER_API int whisper_full_n_segments(struct whisper_context * ctx); + + // Language id associated with the current context + WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); + + // Get the start and end time of the specified segment. + WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment); + + // Get the text of the specified segment. + WHISPER_API const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment); + + // Get number of tokens in the specified segment. + WHISPER_API int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment); + + // Get the token text of the specified token in the specified segment. + WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment. + // This contains probabilities, timestamps, etc. + WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment. + WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); + + //////////////////////////////////////////////////////////////////////////// + + // Temporary helpers needed for exposing ggml interface + + WHISPER_API int whisper_bench_memcpy(int n_threads); + WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb new file mode 100644 index 0000000..82c21fe --- /dev/null +++ b/bindings/ruby/tests/test_whisper.rb @@ -0,0 +1,23 @@ +TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) +EXTDIR = File.join(TOPDIR, 'ext') +#$LIBDIR = File.join(TOPDIR, 'lib') +#$:.unshift(LIBDIR) +$:.unshift(EXTDIR) + +require 'whisper' +require 'test/unit' + +class TestWhisper < Test::Unit::TestCase + def setup + @params = Whisper::Params.new + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'for-tests-ggml-base.en.bin')) + end + + def test_autodetect + @params.auto_detection = true + end + + def test_whisper + end + +end