Adding Whisper inference example

experiments/blocking
Georgi Gerganov 2 years ago
parent f21b84cd21
commit 787efb4d2e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -25,6 +25,8 @@ option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
option(GGML_PERF "ggml: enable perf timings" ${GGML_PERF})
# sanitizers
if (GGML_SANITIZE_THREAD)

@ -13,7 +13,21 @@ Tensor library for machine learning
- No third-party dependencies
- Zero memory allocations during runtime
## Example - GPT inference
## Whisper inference (example)
With ggml you can efficiently run [Whisper](examples/whisper) inference on the CPU.
Memory requirements:
| Model | Mem |
| --- | --- |
| tiny.en | ~460 MB |
| base.en | ~620 MB |
| small.en | ~1.3 GB |
| medium.en | ~2.8 GB |
| large | ~4.9 GB |
## GPT inference (example)
With ggml you can efficiently run [GPT-2](examples/gpt-2) and [GPT-J](examples/gpt-j) inference on the CPU.

@ -3,3 +3,4 @@ target_include_directories(ggml_utils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
add_subdirectory(gpt-2)
add_subdirectory(gpt-j)
add_subdirectory(whisper)

File diff suppressed because it is too large Load Diff

@ -366,8 +366,6 @@ bool gpt2_eval(
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int d_key = n_embd/n_head;
static size_t buf_size = 256u*1024*1024;
static void * buf = malloc(buf_size);
@ -474,6 +472,18 @@ bool gpt2_eval(
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@ -616,7 +626,7 @@ bool gpt2_eval(
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.wte, inpL);
// to logits
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
// run the computation

@ -558,7 +558,7 @@ bool gptj_eval(
inpL);
}
// to logits
// logits -> probs
inpL = ggml_soft_max(ctx0, inpL);
// run the computation

@ -57,6 +57,25 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf(stderr, "\n");
}
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
case 0: return "So";
case 1: return "Once upon a time";
case 2: return "When";
case 3: return "The";
case 4: return "After";
case 5: return "If";
case 6: return "import";
case 7: return "He";
case 8: return "She";
case 9: return "They";
default: return "To";
}
return "The";
}
void replace(std::string & str, const std::string & needle, const std::string & replacement) {
size_t pos = 0;
while ((pos = str.find(needle, pos)) != std::string::npos) {
@ -65,7 +84,6 @@ void replace(std::string & str, const std::string & needle, const std::string &
}
}
// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname) {
std::map<std::string, int32_t> result;
@ -157,25 +175,6 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
return result;
}
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
case 0: return "So";
case 1: return "Once upon a time";
case 2: return "When";
case 3: return "The";
case 4: return "After";
case 5: return "If";
case 6: return "import";
case 7: return "He";
case 8: return "She";
case 9: return "They";
default: return "To";
}
return "The";
}
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;

@ -28,10 +28,10 @@ struct gpt_params {
std::string prompt;
};
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
std::string gpt_random_prompt(std::mt19937 & rng);
//

@ -0,0 +1,6 @@
#
# whisper
set(TEST_TARGET whisper)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)

@ -0,0 +1,29 @@
# whisper
Port of [OpenAI's Whisper](https://github.com/openai/whisper) ASR model in C/C++ using
[ggml](https://github.com/ggerganov/ggml)
## More info
Checkout https://github.com/ggerganov/whisper.cpp
## Memory usage
| Model | Mem |
| --- | --- |
| tiny.en | ~460 MB |
| base.en | ~620 MB |
| small.en | ~1.3 GB |
| medium.en | ~2.8 GB |
| large | ~4.9 GB |
## ggml format
The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
- model parameters
- mel filters
- vocabulary
- weights
For more details, see the conversion script [convert-pt-to-ggml.py](convert-pt-to-ggml.py)

@ -0,0 +1,328 @@
# Convert Whisper transformer model from PyTorch to ggml format
#
# Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
#
# You need to clone the original repo in ~/path/to/repo/whisper/
#
# git clone https://github.com/openai/whisper ~/path/to/repo/whisper/
#
# It is used to various assets needed by the algorithm:
#
# - tokenizer
# - mel filters
#
# Also, you need to have the original models in ~/.cache/whisper/
# See the original repo for more details.
#
# This script loads the specified model and whisper assets and saves them in ggml format.
# The output is a single binary file containing the following information:
#
# - hparams
# - mel filters
# - tokenizer vocab
# - model variables
#
# For each variable, write the following:
#
# - Number of dimensions (int)
# - Name length (int)
# - Dimensions (int[n_dims])
# - Name (char[name_length])
# - Data (float[n_dims])
#
import io
import os
import sys
import struct
import json
import code
import torch
import numpy as np
from transformers import GPTJForCausalLM
from transformers import GPT2TokenizerFast
# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292
def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
path = os.path.join(path_to_whisper_repo, "whisper/assets", name)
tokenizer = GPT2TokenizerFast.from_pretrained(path)
specials = [
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nocaptions|>",
"<|notimestamps|>",
]
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
return tokenizer
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
if len(sys.argv) < 4:
print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
sys.exit(1)
fname_inp = sys.argv[1]
dir_whisper = sys.argv[2]
dir_out = sys.argv[3]
# try to load PyTorch binary data
try:
model_bytes = open(fname_inp, "rb").read()
with io.BytesIO(model_bytes) as fp:
checkpoint = torch.load(fp, map_location="cpu")
except:
print("Error: failed to load PyTorch model file: %s" % fname_inp)
sys.exit(1)
hparams = checkpoint["dims"]
print("hparams:", hparams)
list_vars = checkpoint["model_state_dict"]
#print(list_vars['encoder.positional_embedding'])
#print(list_vars['encoder.conv1.weight'])
#print(list_vars['encoder.conv1.weight'].shape)
# load mel filters
n_mels = hparams["n_mels"]
with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
filters = torch.from_numpy(f[f"mel_{n_mels}"])
#print (filters)
#code.interact(local=locals())
multilingual = hparams["n_vocab"] == 51865
tokenizer = build_tokenizer(dir_whisper, multilingual and "multilingual" or "gpt2")
#print(tokenizer)
#print(tokenizer.name_or_path)
#print(len(tokenizer.additional_special_tokens))
dir_tokenizer = tokenizer.name_or_path
# output in the same directory as the model
fname_out = dir_out + "/ggml-model.bin"
with open(dir_tokenizer + "/vocab.json", "r") as f:
tokens = json.load(f)
# use 16-bit or 32-bit floats
use_f16 = True
if len(sys.argv) > 4:
use_f16 = False
fname_out = dir_out + "/ggml-model-f32.bin"
fout = open(fname_out, "wb")
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["n_vocab"]))
fout.write(struct.pack("i", hparams["n_audio_ctx"]))
fout.write(struct.pack("i", hparams["n_audio_state"]))
fout.write(struct.pack("i", hparams["n_audio_head"]))
fout.write(struct.pack("i", hparams["n_audio_layer"]))
fout.write(struct.pack("i", hparams["n_text_ctx"]))
fout.write(struct.pack("i", hparams["n_text_state"]))
fout.write(struct.pack("i", hparams["n_text_head"]))
fout.write(struct.pack("i", hparams["n_text_layer"]))
fout.write(struct.pack("i", hparams["n_mels"]))
fout.write(struct.pack("i", use_f16))
# write mel filters
fout.write(struct.pack("i", filters.shape[0]))
fout.write(struct.pack("i", filters.shape[1]))
for i in range(filters.shape[0]):
for j in range(filters.shape[1]):
fout.write(struct.pack("f", filters[i][j]))
byte_encoder = bytes_to_unicode()
byte_decoder = {v:k for k, v in byte_encoder.items()}
fout.write(struct.pack("i", len(tokens)))
for key in tokens:
text = bytearray([byte_decoder[c] for c in key]).decode('utf-8', errors='replace').encode('utf-8')
fout.write(struct.pack("i", len(text)))
fout.write(text)
for name in list_vars.keys():
data = list_vars[name].squeeze().numpy()
print("Processing variable: " + name + " with shape: ", data.shape)
# reshape conv bias from [n] to [n, 1]
if name == "encoder.conv1.bias" or \
name == "encoder.conv2.bias":
data = data.reshape(data.shape[0], 1)
print(" Reshaped variable: " + name + " to shape: ", data.shape)
n_dims = len(data.shape);
# looks like the whisper models are in f16 by default
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
# ftype == 0 -> float32, ftype == 1 -> float16
ftype = 1;
if use_f16:
if n_dims < 2 or \
name == "encoder.conv1.bias" or \
name == "encoder.conv2.bias" or \
name == "encoder.positional_embedding" or \
name == "decoder.positional_embedding":
ftype = 0
data = data.astype(np.float32)
print(" Converting to float32")
data = data.astype(np.float32)
ftype = 0
else:
data = data.astype(np.float32)
ftype = 0
#if name.startswith("encoder"):
# if name.endswith("mlp.0.weight") or \
# name.endswith("mlp.2.weight"):
# print(" Transposing")
# data = data.transpose()
# header
str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(str);
# data
data.tofile(fout)
fout.close()
print("Done. Output file: " + fname_out)
print("")

File diff suppressed because it is too large Load Diff

@ -12,6 +12,7 @@ extern "C" {
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16
#define GGML_MAX_CONTEXTS 16
#define GGML_MAX_OPT 4
#ifdef __ARM_NEON
// we use the built-in 16-bit float type
@ -68,6 +69,11 @@ enum ggml_op {
GGML_OP_DIAG_MASK_INF,
GGML_OP_SOFT_MAX,
GGML_OP_ROPE,
GGML_OP_CONV_1D_1S,
GGML_OP_CONV_1D_2S,
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
GGML_OP_COUNT,
};
@ -91,6 +97,7 @@ struct ggml_tensor {
struct ggml_tensor * grad;
struct ggml_tensor * src0;
struct ggml_tensor * src1;
struct ggml_tensor * opt[GGML_MAX_OPT];
// thread scheduling
int n_tasks;
@ -180,14 +187,19 @@ struct ggml_tensor * ggml_new_tensor_4d(
int ne2,
int ne3);
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
@ -383,6 +395,35 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode);
// padding = 1
// TODO: we don't support extra parameters for now
// that's why we are hard-coding the stride, padding, and dilation
// not great ..
struct ggml_tensor * ggml_conv_1d_1s(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_conv_1d_2s(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
bool masked);
struct ggml_tensor * ggml_flash_ff(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b0,
struct ggml_tensor * b1,
struct ggml_tensor * c0,
struct ggml_tensor * c1);
//
// automatic differentiation
//

@ -59,6 +59,7 @@ add_library(${TARGET}
target_include_directories(${TARGET} PUBLIC
.
../include
../include/ggml
)
target_link_libraries(${TARGET} PUBLIC m ${GGML_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save