diff --git a/README.md b/README.md index afa0e7e..02b0c09 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,11 @@ The main goal is to run the model using 4-bit quantization on a MacBook. This was hacked in an evening - I have no idea if it works correctly. -So far, I've tested just the 7B model. -Here is a typical run: +Here is a typical run using LLaMA-7B: ```java -make -j && ./main -m ../LLaMA-4bit/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -t 8 -n 512 -I llama.cpp build info: +make -j && ./main -m ./models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -t 8 -n 512 +I llama.cpp build info: I UNAME_S: Darwin I UNAME_P: arm I UNAME_M: arm64 @@ -34,7 +33,7 @@ I CXX: Apple clang version 14.0.0 (clang-1400.0.29.202) make: Nothing to be done for `default'. main: seed = 1678486056 -llama_model_load: loading model from '../LLaMA-4bit/7B/ggml-model-q4_0.bin' - please wait ... +llama_model_load: loading model from './models/7B/ggml-model-q4_0.bin' - please wait ... llama_model_load: n_vocab = 32000 llama_model_load: n_ctx = 512 llama_model_load: n_embd = 4096 @@ -110,6 +109,8 @@ https://user-images.githubusercontent.com/1991296/224442907-7693d4be-acaa-4e01-8 ## Usage +Here are the step for the LLaMA-7B model: + ```bash # build this repo git clone https://github.com/ggerganov/llama.cpp @@ -133,9 +134,40 @@ python3 convert-pth-to-ggml.py models/7B/ 1 ./main -m ./models/7B/ggml-model-q4_0.bin -t 8 -n 128 ``` +For the bigger models, there are a few extra quantization steps. For example, for LLaMA-13B, converting to FP16 format +will create 2 ggml files, instead of one: + +```bash +ggml-model-f16.bin +ggml-model-f16.bin.1 +``` + +You need to quantize each of them separately like this: + +```bash +./quantize ./models/13B/ggml-model-f16.bin ./models/13B/ggml-model-q4_0.bin 2 +./quantize ./models/13B/ggml-model-f16.bin.1 ./models/13B/ggml-model-q4_0.bin.1 2 +``` + +Everything else is the same. Simply run: + +```bash +./main -m ./models/13B/ggml-model-q4_0.bin -t 8 -n 128 +``` + +The number of files generated for each model is as follows: + +``` +7B -> 1 file +13B -> 2 files +33B -> 4 files +65B -> 8 files +``` + +When running the larger models, make sure you have enough disk space to store all the intermediate files. + ## Limitations -- Currently, only LLaMA-7B is supported since I haven't figured out how to merge the tensors of the bigger models. However, in theory, you should be able to run 65B on a 64GB MacBook - Not sure if my tokenizer is correct. There are a few places where we might have a mistake: - https://github.com/ggerganov/llama.cpp/blob/26c084662903ddaca19bef982831bfb0856e8257/convert-pth-to-ggml.py#L79-L87 - https://github.com/ggerganov/llama.cpp/blob/26c084662903ddaca19bef982831bfb0856e8257/utils.h#L65-L69 diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index bd0a9d0..fc217c7 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -33,12 +33,23 @@ if len(sys.argv) < 3: # output in the same directory as the model dir_model = sys.argv[1] -fname_out = sys.argv[1] + "/ggml-model.bin" fname_hparams = sys.argv[1] + "/params.json" -fname_model = sys.argv[1] + "/consolidated.00.pth" fname_tokenizer = sys.argv[1] + "/../tokenizer.model" +def get_n_parts(dim): + if dim == 4096: + return 1 + elif dim == 5120: + return 2 + elif dim == 6656: + return 4 + elif dim == 8192: + return 8 + else: + print("Invalid dim: " + str(dim)) + sys.exit(1) + # possible data types # ftype == 0 -> float32 # ftype == 1 -> float16 @@ -61,76 +72,91 @@ tokenizer = SentencePieceProcessor(fname_tokenizer) hparams.update({"vocab_size": tokenizer.vocab_size()}) +n_parts = get_n_parts(hparams["dim"]) + print(hparams) +print('n_parts = ', n_parts) -model = torch.load(fname_model, map_location="cpu") - -fout = open(fname_out, "wb") - -fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex -fout.write(struct.pack("i", hparams["vocab_size"])) -fout.write(struct.pack("i", hparams["dim"])) -fout.write(struct.pack("i", hparams["multiple_of"])) -fout.write(struct.pack("i", hparams["n_heads"])) -fout.write(struct.pack("i", hparams["n_layers"])) -fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete) -fout.write(struct.pack("i", ftype)) - -# Is this correct?? -for i in range(32000): - # TODO: this is probably wrong - not sure how this tokenizer works - text = tokenizer.decode([29889, i]).encode('utf-8') - # remove the first byte (it's always '.') - text = text[1:] - fout.write(struct.pack("i", len(text))) - fout.write(text) - -for k, v in model.items(): - name = k - shape = v.shape - - # skip layers.X.attention.inner_attention.rope.freqs - if name[-5:] == "freqs": - continue - - print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype) - - #data = tf.train.load_variable(dir_model, name).squeeze() - data = v.numpy().squeeze() - n_dims = len(data.shape); - - # for efficiency - transpose some matrices - # "model/h.*/attn/c_attn/w" - # "model/h.*/attn/c_proj/w" - # "model/h.*/mlp/c_fc/w" - # "model/h.*/mlp/c_proj/w" - #if name[-14:] == "/attn/c_attn/w" or \ - # name[-14:] == "/attn/c_proj/w" or \ - # name[-11:] == "/mlp/c_fc/w" or \ - # name[-13:] == "/mlp/c_proj/w": - # print(" Transposing") - # data = data.transpose() - - dshape = data.shape - - # default type is fp16 - ftype_cur = 1 - if ftype == 0 or n_dims == 1: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 - - # header - str = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) - for i in range(n_dims): - fout.write(struct.pack("i", dshape[n_dims - 1 - i])) - fout.write(str); - - # data - data.tofile(fout) - -fout.close() - -print("Done. Output file: " + fname_out) -print("") +for p in range(n_parts): + print('Processing part ', p) + + #fname_model = sys.argv[1] + "/consolidated.00.pth" + fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth" + fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + if (p > 0): + fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p) + + model = torch.load(fname_model, map_location="cpu") + + fout = open(fname_out, "wb") + + fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex + fout.write(struct.pack("i", hparams["vocab_size"])) + fout.write(struct.pack("i", hparams["dim"])) + fout.write(struct.pack("i", hparams["multiple_of"])) + fout.write(struct.pack("i", hparams["n_heads"])) + fout.write(struct.pack("i", hparams["n_layers"])) + fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete) + fout.write(struct.pack("i", ftype)) + + # Is this correct?? + for i in range(32000): + # TODO: this is probably wrong - not sure how this tokenizer works + text = tokenizer.decode([29889, i]).encode('utf-8') + # remove the first byte (it's always '.') + text = text[1:] + fout.write(struct.pack("i", len(text))) + fout.write(text) + + for k, v in model.items(): + name = k + shape = v.shape + + # skip layers.X.attention.inner_attention.rope.freqs + if name[-5:] == "freqs": + continue + + print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype) + + #data = tf.train.load_variable(dir_model, name).squeeze() + data = v.numpy().squeeze() + n_dims = len(data.shape); + + # for efficiency - transpose some matrices + # "model/h.*/attn/c_attn/w" + # "model/h.*/attn/c_proj/w" + # "model/h.*/mlp/c_fc/w" + # "model/h.*/mlp/c_proj/w" + #if name[-14:] == "/attn/c_attn/w" or \ + # name[-14:] == "/attn/c_proj/w" or \ + # name[-11:] == "/mlp/c_fc/w" or \ + # name[-13:] == "/mlp/c_proj/w": + # print(" Transposing") + # data = data.transpose() + + dshape = data.shape + + # default type is fp16 + ftype_cur = 1 + if ftype == 0 or n_dims == 1: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + # header + sname = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", dshape[n_dims - 1 - i])) + fout.write(sname); + + # data + data.tofile(fout) + + # I hope this deallocates the memory .. + model = None + + fout.close() + + print("Done. Output file: " + fname_out + ", (part ", p, ")") + print("") diff --git a/ggml.c b/ggml.c index ee3b0af..bb714e2 100644 --- a/ggml.c +++ b/ggml.c @@ -366,9 +366,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; + const size_t bs = sizeof(float) + QK/2; - float * restrict pd = (float *) (y); - uint8_t * restrict pb = (uint8_t *) (pd + nb); + uint8_t * restrict pd = (uint8_t *) (y + 0*bs); + uint8_t * restrict pb = (uint8_t *) (y + 0*bs + sizeof(float)); uint8_t pp[QK/2]; @@ -395,7 +396,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0/d : 0.0; - pd[i] = d; + *(float *)pd = d; + pd += bs; for (int l = 0; l < 8; l++) { const float32x4_t v = vmulq_n_f32(srcv[l], id); @@ -406,7 +408,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); } - memcpy(pb + i*16, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } #else #error "not implemented for QK" @@ -434,7 +437,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0/d : 0.0; - pd[i] = d; + *(float *)pd = d; + pd += bs; for (int l = 0; l < 8; l++) { const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); @@ -445,7 +449,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); } - memcpy(pb + i*16, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } #else #error "not implemented for QK" @@ -463,7 +468,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0f/d : 0.0f; - pd[i] = d; + *(float *)pd = d; + pd += bs; for (int l = 0; l < QK; l += 2) { const float v0 = x[i*QK + l + 0]*id; @@ -478,7 +484,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pp[l/2] = vi0 | (vi1 << 4); } - memcpy(pb + i*QK/2, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } #endif } @@ -535,15 +542,16 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; + const size_t bs = sizeof(float) + QK/2; - const float * restrict pd = (const float *) (x); - const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs); + const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float)); // scalar for (int i = 0; i < nb; i++) { - const float d = pd[i]; + const float d = *(const float *) (pd + i*bs); - const uint8_t * restrict pp = pb + i*QK/2; + const uint8_t * restrict pp = pb + i*bs; for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; @@ -554,6 +562,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const float v0 = (vi0 - 8)*d; const float v1 = (vi1 - 8)*d; + //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1); + y[i*QK + l + 0] = v0; y[i*QK + l + 1] = v1; @@ -1179,11 +1189,13 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void assert(n % QK == 0); assert(nb % 2 == 0); - const float * restrict pd0 = (const float *) x; - const float * restrict pd1 = (const float *) y; + const size_t bs = sizeof(float) + QK/2; - const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); - const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + const uint8_t * restrict pd0 = (const uint8_t *) (x + 0*bs); + const uint8_t * restrict pd1 = (const uint8_t *) (y + 0*bs); + + const uint8_t * restrict pb0 = (const uint8_t *) (x + 0*bs + sizeof(float)); + const uint8_t * restrict pb1 = (const uint8_t *) (y + 0*bs + sizeof(float)); float sumf = 0.0; @@ -1193,23 +1205,23 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void float sum1 = 0.0f; for (int i = 0; i < nb; i += 2) { - const float d0_0 = pd0[i + 0]; - const float d1_0 = pd1[i + 0]; - const float d0_1 = pd0[i + 1]; - const float d1_1 = pd1[i + 1]; + const float d0_0 = *(const float *) (pd0 + i*bs); + const float d1_0 = *(const float *) (pd1 + i*bs); + const float d0_1 = *(const float *) (pd0 + (i + 1)*bs); + const float d1_1 = *(const float *) (pd1 + (i + 1)*bs); //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1); - const uint8_t * restrict p0 = pb0 + i*16; - const uint8_t * restrict p1 = pb1 + i*16; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; const uint8x16_t m4b = vdupq_n_u8(0xf); const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(p0); const uint8x16_t v1_0 = vld1q_u8(p1); - const uint8x16_t v0_1 = vld1q_u8(p0 + 16); - const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + const uint8x16_t v0_1 = vld1q_u8(p0 + bs); + const uint8x16_t v1_1 = vld1q_u8(p1 + bs); // 4-bit -> 8-bit const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); @@ -1280,21 +1292,21 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void float sum1 = 0.0f; for (int i = 0; i < nb; i += 2) { - const float d0_0 = pd0[i + 0]; - const float d0_1 = pd0[i + 1]; - const float d1_0 = pd1[i + 0]; - const float d1_1 = pd1[i + 1]; + const float d0_0 = *(const float *) (pd0 + i*bs); + const float d1_0 = *(const float *) (pd1 + i*bs); + const float d0_1 = *(const float *) (pd0 + (i + 1)*bs); + const float d1_1 = *(const float *) (pd1 + (i + 1)*bs); - const uint8_t * restrict p0 = pb0 + i*16; - const uint8_t * restrict p1 = pb1 + i*16; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; const v128_t m4b = wasm_u8x16_splat(0xf); const v128_t s8b = wasm_i8x16_splat(0x8); const v128_t v0_0 = wasm_v128_load(p0); - const v128_t v0_1 = wasm_v128_load(p0 + 16); + const v128_t v0_1 = wasm_v128_load(p0 + bs); const v128_t v1_0 = wasm_v128_load(p1); - const v128_t v1_1 = wasm_v128_load(p1 + 16); + const v128_t v1_1 = wasm_v128_load(p1 + bs); // 4-bit -> 8-bit const v128_t v0_0l = wasm_v128_and(v0_0, m4b); @@ -1363,11 +1375,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void #else // scalar for (int i = 0; i < nb; i++) { - const float d0 = pd0[i]; - const float d1 = pd1[i]; + const float d0 = *(const float *) (pd0 + i*bs); + const float d1 = *(const float *) (pd1 + i*bs); - const uint8_t * restrict p0 = pb0 + i*QK/2; - const uint8_t * restrict p1 = pb1 + i*QK/2; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; for (int j = 0; j < QK/2; j++) { const uint8_t v0 = p0[j]; @@ -1552,16 +1564,17 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res assert(n % QK == 0); const int nb = n / QK; + const size_t bs = sizeof(float) + QK/2; - const float * restrict pd = (const float *) (x); - const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs); + const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float)); #if __ARM_NEON #if QK == 32 for (int i = 0; i < nb; ++i) { - const float d0 = pd[i]*v; + const float d0 = v*(*(const float *) (pd + i*bs)); - const uint8_t * restrict pp = pb + i*16; + const uint8_t * restrict pp = pb + i*bs; const uint8x8_t m4b = vdup_n_u8(0xf); const int8x8_t s8b = vdup_n_s8(0x8); @@ -1615,9 +1628,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res #else // scalar for (int i = 0; i < nb; i++) { - const float d = pd[i]; + const float d = *(const float *) (pd + i*bs); - const uint8_t * restrict pp = pb + i*QK/2; + const uint8_t * restrict pp = pb + i*bs; for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; diff --git a/main.cpp b/main.cpp index eca7140..d28fc91 100644 --- a/main.cpp +++ b/main.cpp @@ -11,6 +11,14 @@ #include #include +// determine number of model parts based on the dimension +static const std::map LLAMA_N_PARTS = { + { 4096, 1 }, + { 5120, 2 }, + { 6656, 4 }, + { 8192, 8 }, +}; + // default hparams (LLaMA 7B) struct llama_hparams { int32_t n_vocab = 32000; @@ -82,6 +90,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab } int n_ff = 0; + int n_parts = 0; // load hparams { @@ -99,6 +108,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab hparams.n_ctx = n_ctx; n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; + n_parts = LLAMA_N_PARTS.at(hparams.n_embd); printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); @@ -109,6 +119,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab printf("%s: n_rot = %d\n", __func__, hparams.n_rot); printf("%s: f16 = %d\n", __func__, hparams.f16); printf("%s: n_ff = %d\n", __func__, n_ff); + printf("%s: n_parts = %d\n", __func__, n_parts); } // load vocab @@ -220,7 +231,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab model.layers.resize(n_layer); - model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); model.output = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); @@ -234,14 +245,14 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab for (int i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; - layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); - layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff); layer.w2 = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd); @@ -282,94 +293,208 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem); } - // load weights - { - int n_tensors = 0; - size_t total_size = 0; + const size_t file_offset = fin.tellg(); - printf("%s: ", __func__); + fin.close(); - while (true) { - int32_t n_dims; - int32_t length; - int32_t ftype; + std::vector tmp; - fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); - fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + for (int i = 0; i < n_parts; ++i) { + const int part_id = i; + //const int part_id = n_parts - i - 1; - if (fin.eof()) { - break; - } + std::string fname_part = fname; + if (i > 0) { + fname_part += "." + std::to_string(i); + } - int32_t nelements = 1; - int32_t ne[2] = { 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); - nelements *= ne[i]; - } + printf("%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str()); - std::string name(length, 0); - fin.read(&name[0], length); + fin = std::ifstream(fname_part, std::ios::binary); + fin.seekg(file_offset); - if (model.tensors.find(name.data()) == model.tensors.end()) { - fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); - return false; - } + // load weights + { + int n_tensors = 0; + size_t total_size = 0; - auto tensor = model.tensors[name.data()]; - if (ggml_nelements(tensor) != nelements) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - return false; - } + printf("%s: ", __func__); - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", - __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); - return false; - } + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; - if (0) { - static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; - printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); - } + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } - size_t bpe = 0; + std::string name(length, 0); + fin.read(&name[0], length); - switch (ftype) { - case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; - case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; - case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; - case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; - default: - { - fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + // split_type = 0: split by columns + // split_type = 1: split by rows + int split_type = 0; + + // split_type = 0: + // regex: + // - tok_embeddings.* + // - layers.*.attention.wo.weight + // - layers.*.feed_forward.w2.weight + + // split_type = 1: + // regex: + // - output.* + // - layers.*.attention.wq.weight + // - layers.*.attention.wk.weight + // - layers.*.attention.wv.weight + // - layers.*.feed_forward.w1.weight + // - layers.*.feed_forward.w3.weight + if (name.find("tok_embeddings") != std::string::npos) { + split_type = 0; + } else if (name.find("layers") != std::string::npos) { + if (name.find("attention.wo.weight") != std::string::npos) { + split_type = 0; + } else if (name.find("feed_forward.w2.weight") != std::string::npos) { + split_type = 0; + } else { + split_type = 1; + } + } else if (name.find("output") != std::string::npos) { + split_type = 1; + } + + auto tensor = model.tensors[name.data()]; + + if (n_dims == 1) { + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + } else { + if (ggml_nelements(tensor)/n_parts != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + } + + if (n_dims == 1) { + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); + return false; + } + } else { + if (split_type == 0) { + if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]); + return false; + } + } else { + if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]); return false; } - }; + } + } - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); - return false; - } + if (0) { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type); + } + + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; + case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; + case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return false; + } + }; + + if (n_dims == 1 || n_parts == 1) { + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + if (part_id == 0) { + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + } else { + fin.seekg(ggml_nbytes(tensor), std::ios::cur); + } + + total_size += ggml_nbytes(tensor); + } else { + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe); + return false; + } + + if (split_type == 0) { + const int np0 = ne[0]; + + const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type); + assert(row_size == tensor->nb[1]); + + for (int i1 = 0; i1 < ne[1]; ++i1) { + const size_t offset_row = i1*row_size; + const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type); + fin.read(reinterpret_cast(tensor->data) + offset, row_size/n_parts); + } + } else { + const int np1 = ne[1]; - fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type); - //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); - total_size += ggml_nbytes(tensor); - if (++n_tensors % 8 == 0) { - printf("."); - fflush(stdout); + for (int i1 = 0; i1 < ne[1]; ++i1) { + const size_t offset_row = (i1 + part_id*np1)*row_size; + fin.read(reinterpret_cast(tensor->data) + offset_row, row_size); + } + } + + total_size += ggml_nbytes(tensor)/n_parts; + } + + //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); + if (++n_tensors % 8 == 0) { + printf("."); + fflush(stdout); + } } - } - printf(" done\n"); + printf(" done\n"); - printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); - } + printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); + } - fin.close(); + fin.close(); + } return true; } diff --git a/utils.cpp b/utils.cpp index 6bd1fc0..abb3475 100644 --- a/utils.cpp +++ b/utils.cpp @@ -448,7 +448,8 @@ gpt_vocab::id llama_sample_top_p( size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) { const int nb = k / qk; - const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*qk/2); + const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2); + const size_t row_size = nb*bs; assert(k % qk == 0); @@ -457,8 +458,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t char * pdst = (char *) dst; for (int j = 0; j < n; j += k) { - float * pd = (float *) (pdst + (j/k)*row_size); - uint8_t * pb = (uint8_t *) (pd + nb); + uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs); + uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float)); for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max @@ -472,7 +473,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0f/d : 0.0f; - pd[i] = d; + *(float *) pd = d; + pd += bs; for (int l = 0; l < qk; l += 2) { const float v0 = (src[j + i*qk + l + 0])*id; @@ -490,7 +492,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t pp[l/2] = vi0 | (vi1 << 4); } - memcpy(pb + i*qk/2, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } } }