From 7f32376b700ca210712c975eafa86790f7fe302f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 10 Mar 2023 19:01:07 +0200
Subject: [PATCH] llama : initial working FP16 + 4-bit Q4_0

---
 examples/CMakeLists.txt               |   1 +
 examples/llama/CMakeLists.txt         |  13 +
 examples/llama/convert-pth-to-ggml.py |  14 +-
 examples/llama/main.cpp               | 747 ++++++++++++++++++++++++++
 examples/llama/quantize.cpp           | 330 ++++++++++++
 examples/utils.cpp                    |  32 ++
 examples/utils.h                      |   4 +
 include/ggml/ggml.h                   |   5 +
 src/ggml.c                            | 156 +++++-
 9 files changed, 1295 insertions(+), 7 deletions(-)
 create mode 100644 examples/llama/CMakeLists.txt
 create mode 100644 examples/llama/main.cpp
 create mode 100644 examples/llama/quantize.cpp

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 5f7f3a4..350065f 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -4,3 +4,4 @@ target_include_directories(ggml_utils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
 add_subdirectory(gpt-2)
 add_subdirectory(gpt-j)
 add_subdirectory(whisper)
+add_subdirectory(llama)
diff --git a/examples/llama/CMakeLists.txt b/examples/llama/CMakeLists.txt
new file mode 100644
index 0000000..d40575b
--- /dev/null
+++ b/examples/llama/CMakeLists.txt
@@ -0,0 +1,13 @@
+#
+# llama
+
+set(TEST_TARGET llama)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
+
+#
+# llama-quantize
+
+set(TEST_TARGET llama-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
diff --git a/examples/llama/convert-pth-to-ggml.py b/examples/llama/convert-pth-to-ggml.py
index a966ce3..d0a187c 100644
--- a/examples/llama/convert-pth-to-ggml.py
+++ b/examples/llama/convert-pth-to-ggml.py
@@ -73,18 +73,26 @@ fout.write(struct.pack("i", hparams["dim"]))
 fout.write(struct.pack("i", hparams["multiple_of"]))
 fout.write(struct.pack("i", hparams["n_heads"]))
 fout.write(struct.pack("i", hparams["n_layers"]))
+fout.write(struct.pack("i", 64)) # rot
 fout.write(struct.pack("i", ftype))
 
 # Is this correct??
 for i in range(32000):
-    text = tokenizer.decode(i)
+    # TODO: this is probably wrong - not sure how this tokenizer works
+    text = tokenizer.decode([29889, i]).encode('utf-8')
+    # remove the first byte (it's always '.')
+    text = text[1:]
     fout.write(struct.pack("i", len(text)))
-    fout.write(text.encode('utf-8'))
+    fout.write(text)
 
 for k, v in model.items():
     name = k
     shape = v.shape
 
+    # skip layers.X.attention.inner_attention.rope.freqs
+    if name[-5:] == "freqs":
+        continue
+
     print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
 
     #data = tf.train.load_variable(dir_model, name).squeeze()
@@ -107,7 +115,7 @@ for k, v in model.items():
 
     # default type is fp16
     ftype_cur = 1
-    if ftype == 0:
+    if ftype == 0 or n_dims == 1:
         print(" Converting to float32")
         data = data.astype(np.float32)
         ftype_cur = 0
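For reference, the header that convert-pth-to-ggml.py emits (and that llama_model_load below reads back) is a flat sequence of little-endian 32-bit fields. A sketch of the layout as a single struct — the struct and field names are illustrative only; the loader actually reads the fields one at a time:

    // sketch: on-disk layout of the ggml model header written above
    struct llama_file_header {
        uint32_t magic;    // 0x67676d6c, i.e. "ggml"
        int32_t  n_vocab;  // 32000
        int32_t  n_embd;   // hparams["dim"]
        int32_t  n_mult;   // hparams["multiple_of"]
        int32_t  n_head;   // hparams["n_heads"]
        int32_t  n_layer;  // hparams["n_layers"]
        int32_t  n_rot;    // hardcoded to 64 above
        int32_t  f16;      // ftype: 0 = f32, 1 = f16
    };
    // the vocab follows: n_vocab records of (int32 len, len bytes of token text),
    // then one record per tensor (dims, name length, type, shape, name, data)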
diff --git a/examples/llama/main.cpp b/examples/llama/main.cpp
new file mode 100644
index 0000000..23d6357
--- /dev/null
+++ b/examples/llama/main.cpp
@@ -0,0 +1,747 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+
+// default hparams (LLaMA 7B)
+struct llama_hparams {
+    int32_t n_vocab = 32000;
+    int32_t n_ctx   = 512;   // this is provided as user input?
+    int32_t n_embd  = 4096;
+    int32_t n_mult  = 256;
+    int32_t n_head  = 32;
+    int32_t n_layer = 32;
+    int32_t n_rot   = 64;
+    int32_t f16     = 1;
+};
+
+struct llama_layer {
+    // normalization
+    struct ggml_tensor * attention_norm;
+
+    // attention
+    struct ggml_tensor * wq;
+    struct ggml_tensor * wk;
+    struct ggml_tensor * wv;
+    struct ggml_tensor * wo;
+
+    // normalization
+    struct ggml_tensor * ffn_norm;
+
+    // ff
+    struct ggml_tensor * w1;
+    struct ggml_tensor * w2;
+    struct ggml_tensor * w3;
+};
+
+struct llama_model {
+    llama_hparams hparams;
+
+    struct ggml_tensor * tok_embeddings;
+
+    struct ggml_tensor * norm;
+    struct ggml_tensor * output;
+
+    std::vector<llama_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    int n_ff = 0;
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fin.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        hparams.n_ctx = n_ctx;
+
+        n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_mult  = %d\n", __func__, hparams.n_mult);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: n_rot   = %d\n", __func__, hparams.n_rot);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+        printf("%s: n_ff    = %d\n", __func__, n_ff);
+    }
+
+    // load vocab
+    {
+        const int32_t n_vocab = model.hparams.n_vocab;
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            word.resize(len);
+            fin.read((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+
+            //if (i < 30000) {
+            //    printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
+            //}
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+    // in order to save memory and also to speed up the computation
+    ggml_type wtype = GGML_TYPE_COUNT;
+    switch (model.hparams.f16) {
+        case 0: wtype = GGML_TYPE_F32;  break;
+        case 1: wtype = GGML_TYPE_F16;  break;
+        case 2: wtype = GGML_TYPE_Q4_0; break;
+        case 3: wtype = GGML_TYPE_Q4_1; break;
+        default:
+                {
+                    fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
+                            __func__, fname.c_str(), model.hparams.f16);
+                    return false;
+                }
+    }
+
+    const ggml_type wtype2 = GGML_TYPE_F32;
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // tok_embeddings
+
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
+
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // output
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
+
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
+        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+
+        ctx_size += (5 + 10*n_layer)*256; // object overhead
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = ctx_size,
+            .mem_buffer = NULL,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
+
+        model.norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.output = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        // map by name
+        model.tensors["tok_embeddings.weight"] = model.tok_embeddings;
+
+        model.tensors["norm.weight"]   = model.norm;
+        model.tensors["output.weight"] = model.output;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+            layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+
+            layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+            layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_ff);
+            layer.w2 = ggml_new_tensor_2d(ctx, wtype,   n_ff, n_embd);
+            layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_ff);
+
+            // map by name
+            model.tensors["layers." + std::to_string(i) + ".attention_norm.weight"] = layer.attention_norm;
+
+            model.tensors["layers." + std::to_string(i) + ".attention.wq.weight"] = layer.wq;
+            model.tensors["layers." + std::to_string(i) + ".attention.wk.weight"] = layer.wk;
+            model.tensors["layers." + std::to_string(i) + ".attention.wv.weight"] = layer.wv;
+            model.tensors["layers." + std::to_string(i) + ".attention.wo.weight"] = layer.wo;
+
+            model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] = layer.ffn_norm;
+
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w1.weight"] = layer.w1;
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w2.weight"] = layer.w2;
+            model.tensors["layers." + std::to_string(i) + ".feed_forward.w3.weight"] = layer.w3;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        int n_tensors = 0;
+        size_t total_size = 0;
+
+        printf("%s: ", __func__);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = model.tensors[name.data()];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            if (0) {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
+
+            size_t bpe = 0;
+
+            switch (ftype) {
+                case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
+                case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
+                case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
+                case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
+                default:
+                        {
+                            fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
+                            return false;
+                        }
+            };
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+            total_size += ggml_nbytes(tensor);
+            if (++n_tensors % 8 == 0) {
+                printf(".");
+                fflush(stdout);
+            }
+        }
+
+        printf(" done\n");
+
+        printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+// The LLaMA model requires about 16MB of memory per input token.
+//
+bool llama_eval(
+        const llama_model & model,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w,
+              size_t                     & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+    const int n_vocab = hparams.n_vocab;
+    const int n_rot   = hparams.n_rot;
+
+    const int d_key = n_embd/n_head;
+
+    static size_t buf_size = 256u*1024*1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        .mem_size   = buf_size,
+        .mem_buffer = buf,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = { .n_threads = n_threads };
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            cur = ggml_norm(ctx0, inpL);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
+                        cur);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_cpy(ctx0,
+                                Qcur,
+                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                            n_past, n_rot, 0),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_rope(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            n_past, n_rot, 1),
+                        0, 2, 1, 3);
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                        );
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            struct ggml_tensor * V_trans =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        1, 2, 0, 3);
+
+            // KQV = transpose(V) * KQ_soft_max
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].wo,
+                    cur);
+        }
+
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF);
+
+                // cur = ffn_norm*cur
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
+                        cur);
+            }
+
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model.layers[il].w3,
+                    cur);
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w1,
+                    cur);
+
+            // SILU activation
+            cur = ggml_silu(ctx0, cur);
+
+            cur = ggml_mul(ctx0, cur, tmp);
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w2,
+                    cur);
+        }
+
+        cur = ggml_add(ctx0, cur, inpFF);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    // norm
+    {
+        inpL = ggml_norm(ctx0, inpL);
+
+        // inpL = norm*inpL
+        inpL = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.norm, inpL),
+                    inpL);
+    }
+
+    // lm_head
+    {
+        inpL = ggml_mul_mat(ctx0, model.output, inpL);
+    }
+
+    // logits -> probs
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute       (ctx0, &gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result for just the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/llama-7B/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    llama_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!llama_model_load(params.model, model, vocab, 512)) {  // TODO: set context from user input ??
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    for (int i = 0; i < (int) embd_inp.size(); i++) {
+        printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+    }
+    printf("\n\n");
+
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() > params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token (id 2 = </s> in the llama vocab)
+        if (embd.back() == 2) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
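The hardcoded defaults above fully determine the derived sizes that llama_model_load prints at startup. A standalone sanity check of the two less obvious ones — this sketch is not part of the patch:

    #include <cstdio>

    int main() {
        const int n_embd = 4096, n_mult = 256, n_layer = 32, n_ctx = 512;

        // feed-forward width, same rounding formula as llama_model_load
        const int n_ff = ((2*(4*n_embd)/3 + n_mult - 1)/n_mult)*n_mult;
        printf("n_ff     = %d\n", n_ff); // 11008 for the 7B defaults

        // f32 key + value memory: two tensors of n_ctx*n_layer*n_embd floats
        const double kv_mb = 2.0*n_ctx*n_layer*n_embd*sizeof(float)/1024.0/1024.0;
        printf("kv cache = %.0f MB\n", kv_mb); // 512 MB at n_ctx = 512

        return 0;
    }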
diff --git a/examples/llama/quantize.cpp b/examples/llama/quantize.cpp
new file mode 100644
index 0000000..3453a2d
--- /dev/null
+++ b/examples/llama/quantize.cpp
@@ -0,0 +1,330 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <regex>
+#include <string>
+#include <vector>
+
+// TODO: move somewhere else
+#define QK 32
+
+// default hparams (LLaMA 7B)
+struct llama_hparams {
+    int32_t n_vocab = 32000;
+    int32_t n_ctx   = 512;   // this is provided as user input?
+    int32_t n_embd  = 4096;
+    int32_t n_mult  = 256;
+    int32_t n_head  = 32;
+    int32_t n_layer = 32;
+    int32_t n_rot   = 64;
+    int32_t f16     = 1;
+};
+
+// quantize a model
+bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) {
+    ggml_type type = GGML_TYPE_Q4_1;
+
+    switch (itype) {
+        case 2: type = GGML_TYPE_Q4_0; break;
+        case 3: type = GGML_TYPE_Q4_1; break;
+        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return false;
+    };
+
+    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
+        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
+        return false;
+    }
+
+    gpt_vocab vocab;
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *) &magic, sizeof(magic));
+    }
+
+    llama_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
+        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        finp.read((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        finp.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        finp.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_mult  = %d\n", __func__, hparams.n_mult);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+
+        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        //fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
+        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fout.write((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
+        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fout.write((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fout.write((char *) &itype,           sizeof(hparams.f16));
+    }
+
+    // load vocab
+    {
+        const int32_t n_vocab = hparams.n_vocab;
+
+        if (n_vocab != hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read ((char *) &len, sizeof(len));
+            fout.write((char *) &len, sizeof(len));
+
+            word.resize(len);
+            finp.read ((char *) word.data(), len);
+            fout.write((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // load weights
+    {
+        size_t total_size_org = 0;
+        size_t total_size_new = 0;
+
+        std::vector<ggml_fp16_t> work;
+
+        std::vector<uint8_t>     data_u8;
+        std::vector<ggml_fp16_t> data_f16;
+        std::vector<float>       data_f32;
+
+        std::vector<int64_t> hist_all(1 << 4, 0);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ftype;
+
+            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            finp.read(reinterpret_cast<char *>(&length), sizeof(length));
+            finp.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+            if (finp.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            finp.read (&name[0], length);
+
+            {
+                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
+            }
+
+            // regexes of tensor names to be quantized
+            const std::vector<std::string> k_names = {
+                ".*weight",
+            };
+
+            bool quantize = false;
+            for (const auto & s : k_names) {
+                if (std::regex_match(name, std::regex(s))) {
+                    quantize = true;
+                    break;
+                }
+            }
+
+            // quantize only 2D tensors
+            quantize &= (n_dims == 2);
+
+            if (quantize) {
+                if (ftype != 0 && ftype != 1) {
+                    fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
+                    return false;
+                }
+
+                if (ftype == 1) {
+                    data_f16.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
+                    data_f32.resize(nelements);
+                    for (int i = 0; i < nelements; ++i) {
+                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
+                    }
+                } else {
+                    data_f32.resize(nelements);
+                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
+                }
+
+                ftype = itype;
+            } else {
+                const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
+
+                data_u8.resize(nelements*bpe);
+                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
+            }
+
+            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fout.write(reinterpret_cast<char *>(&length), sizeof(length));
+            fout.write(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+            for (int i = 0; i < n_dims; ++i) {
+                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            }
+            fout.write(&name[0], length);
+
+            if (quantize) {
+                printf("quantizing .. ");
+                work.resize(nelements); // for quantization
+
+                size_t cur_size = 0;
+                std::vector<int64_t> hist_cur(1 << 4, 0);
+
+                switch (type) {
+                    case GGML_TYPE_Q4_0:
+                        {
+                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data());
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data());
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
+                            return false;
+                        }
+                }
+
+                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
+                total_size_new += cur_size;
+
+                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    hist_all[i] += hist_cur[i];
+                }
+
+                for (int i = 0; i < hist_cur.size(); ++i) {
+                    printf("%5.3f ", hist_cur[i] / (float)nelements);
+                }
+                printf("\n");
+            } else {
+                printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
+                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
+                total_size_new += data_u8.size();
+            }
+
+            total_size_org += nelements * sizeof(float);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
+        printf("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+
+        {
+            int64_t sum_all = 0;
+            for (int i = 0; i < hist_all.size(); ++i) {
+                sum_all += hist_all[i];
+            }
+
+            printf("%s: hist: ", __func__);
+            for (int i = 0; i < hist_all.size(); ++i) {
+                printf("%5.3f ", hist_all[i] / (float)sum_all);
+            }
+            printf("\n");
+        }
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
+        fprintf(stderr, "  type = 2 - q4_0\n");
+        fprintf(stderr, "  type = 3 - q4_1\n");
+        return 1;
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const int itype = atoi(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!llama_model_quantize(fname_inp, fname_out, itype)) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    return 0;
+}
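QK = 32 is the quantization block size: ggml_quantize_q4_0 maps each run of 32 floats to one f32 scale plus 32 packed 4-bit values, and the 16-entry histograms above simply count how often each of the 16 nibble values occurs. A rough sketch of reading such a block back — the exact layout and rounding live in ggml.c, so treat the offset-by-8 encoding here as an assumption, not a spec:

    #include <cstdint>
    #include <cstring>

    // illustrative Q4_0 block decode, assuming: f32 scale, then QK/2 bytes of nibbles
    void dequantize_q4_0_block(const uint8_t * blk, float * out /* 32 floats */) {
        float d;
        memcpy(&d, blk, sizeof(float)); // per-block scale
        const uint8_t * qs = blk + sizeof(float);
        for (int i = 0; i < 32/2; ++i) {
            out[2*i + 0] = ((qs[i] & 0x0F) - 8)*d; // low nibble
            out[2*i + 1] = ((qs[i] >>   4) - 8)*d; // high nibble
        }
    }
    // => 4 + 16 = 20 bytes per 32 weights (vs 64 bytes for f16),
    //    i.e. about 5 bits per weight after block overhead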
"); + work.resize(nelements); // for quantization + + size_t cur_size = 0; + std::vector hist_cur(1 << 4, 0); + + switch (type) { + case GGML_TYPE_Q4_0: + { + cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data()); + } break; + case GGML_TYPE_Q4_1: + { + cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data()); + } break; + default: + { + fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type); + return false; + } + } + + fout.write(reinterpret_cast(work.data()), cur_size); + total_size_new += cur_size; + + printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); + for (int i = 0; i < hist_cur.size(); ++i) { + hist_all[i] += hist_cur[i]; + } + + for (int i = 0; i < hist_cur.size(); ++i) { + printf("%5.3f ", hist_cur[i] / (float)nelements); + } + printf("\n"); + } else { + printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); + fout.write(reinterpret_cast(data_u8.data()), data_u8.size()); + total_size_new += data_u8.size(); + } + + total_size_org += nelements * sizeof(float); + } + + printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + printf("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + { + int64_t sum_all = 0; + for (int i = 0; i < hist_all.size(); ++i) { + sum_all += hist_all[i]; + } + + printf("%s: hist: ", __func__); + for (int i = 0; i < hist_all.size(); ++i) { + printf("%5.3f ", hist_all[i] / (float)sum_all); + } + printf("\n"); + } + } + + finp.close(); + fout.close(); + + return true; +} + +// usage: +// ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type +// +int main(int argc, char ** argv) { + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + fprintf(stderr, " type = 2 - q4_0\n"); + fprintf(stderr, " type = 3 - q4_1\n"); + return 1; + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + + const int itype = atoi(argv[3]); + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_quantize_us = 0; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!llama_model_quantize(fname_inp, fname_out, itype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); + return 1; + } + + t_quantize_us = ggml_time_us() - t_start_us; + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n"); + printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + return 0; +} diff --git a/examples/utils.cpp b/examples/utils.cpp index 402a1fd..fe7a489 100644 --- a/examples/utils.cpp +++ b/examples/utils.cpp @@ -230,6 +230,38 @@ std::vector gpt_tokenize(const gpt_vocab & vocab, const std::stri return tokens; } +std::vector llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) { + std::vector res; + + if (bos) { + res.push_back(1); // TODO: replace with vocab.bos + } + + // find the longest token that matches the text + int pos = 0; + while (true) { + int l = 0; + int t = 0; + for (const auto & kv : vocab.id_to_token) { + if (kv.second.size() < l) continue; + if (kv.second.size() > text.size() - pos) continue; + if (text.substr(pos, kv.second.size()) == kv.second) { + l = kv.second.size(); + t = kv.first; + } + } + + if (l == 0) { + 
diff --git a/examples/utils.h b/examples/utils.h
index f7d0dbc..a5cfbd8 100644
--- a/examples/utils.h
+++ b/examples/utils.h
@@ -63,6 +63,10 @@ std::map<std::string, int32_t> json_parse(const std::string & fname);
 //
 std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
 
+// TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
+// ref: https://github.com/google/sentencepiece
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos);
+
 // load the tokens from encoder.json
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index cdd8553..7ce655c 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -228,6 +228,7 @@ enum ggml_op {
     GGML_OP_STEP,
     GGML_OP_RELU,
     GGML_OP_GELU,
+    GGML_OP_SILU,
     GGML_OP_NORM, // normalize
 
     GGML_OP_MUL_MAT,
@@ -471,6 +472,10 @@ struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
 
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
 // normalize along rows
 // TODO: eps is hardcoded to 1e-5 for now
 struct ggml_tensor * ggml_norm(
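The new op plugs into the usual ggml graph flow. A minimal standalone sketch exercising ggml_silu through the public API above — buffer sizes are arbitrary, and the graph setup copies the pattern used by main.cpp:

    #include "ggml/ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params ip = {
            .mem_size   = 16*1024*1024,
            .mem_buffer = NULL,
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        float * xd = (float *) x->data;
        xd[0] = -2.0f; xd[1] = -1.0f; xd[2] = 0.0f; xd[3] = 1.0f;

        struct ggml_tensor * y = ggml_silu(ctx, x); // y = x * sigmoid(x)

        struct ggml_cgraph gf = { .n_threads = 1 };
        ggml_build_forward_expand(&gf, y);
        ggml_graph_compute(ctx, &gf);

        for (int i = 0; i < 4; ++i) {
            printf("silu(%5.2f) = %8.5f\n", xd[i], ((float *) y->data)[i]);
        }

        ggml_free(ctx);
        return 0;
    }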
"gelu(x)", + "silu(x)", "norm(x)", "X*Y", @@ -1902,7 +1937,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; -static_assert(GGML_OP_COUNT == 33, "GGML_OP_COUNT != 33"); +static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34"); // // ggml object @@ -2144,7 +2179,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { static bool is_first_call = true; if (is_first_call) { - // initialize GELU, EXP and F32 tables + // initialize GELU, SILU and EXP F32 tables { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); @@ -2154,12 +2189,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { memcpy(&ii, &ui, sizeof(ii)); const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } // initialize g_state @@ -3318,6 +3354,40 @@ struct ggml_tensor * ggml_gelu_inplace( return ggml_gelu_impl(ctx, a, true); } +// ggml_silu + +struct ggml_tensor * ggml_silu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SILU; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_silu_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_silu_impl(ctx, a, true); +} + // ggml_norm struct ggml_tensor * ggml_norm_impl( @@ -4991,6 +5061,72 @@ static void ggml_compute_forward_gelu( //printf("XXXXXXXX gelu\n"); } +// ggml_compute_forward_silu + +static void ggml_compute_forward_silu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_f32(params, src0, dst); + } break; + case GGML_TYPE_Q4_0: + 
@@ -4991,6 +5061,72 @@ static void ggml_compute_forward_gelu(
     //printf("XXXXXXXX gelu\n");
 }
 
+// ggml_compute_forward_silu
+
+static void ggml_compute_forward_silu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -8183,6 +8319,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_NORM:
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
@@ -8421,6 +8561,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_SILU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_NORM:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -8846,6 +8990,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
+                case GGML_OP_SILU:
+                    {
+                        node->n_tasks = n_threads;
+                    } break;
                case GGML_OP_NORM:
                    {
                        node->n_tasks = n_threads;
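A closing note on the GGML_SILU_FP16 path enabled above: because f16 has only 1 << 16 distinct bit patterns, silu can be precomputed once into a 65536-entry table (2 bytes per entry = 128 KB, matching the comment) and then applied as a single lookup per element, at the cost of first rounding each f32 input to f16. A standalone sketch of the reference function it tabulates:

    #include <cmath>
    #include <cstdio>

    // silu(x) = x * sigmoid(x) = x / (1 + e^(-x)); same math as ggml_silu_f32
    static float silu_ref(float x) { return x/(1.0f + expf(-x)); }

    int main() {
        printf("table size = %zu KB\n", ((size_t) 1 << 16)*2/1024); // 128 KB
        printf("silu(1.0)  = %f\n", silu_ref(1.0f));                // ~0.731059
        return 0;
    }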