Support all LLaMA models + change Q4_0 quantization storage

pull/16/head
Georgi Gerganov 1 year ago
parent 5f2f970d51
commit 007a8f6f45
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -17,12 +17,11 @@ The main goal is to run the model using 4-bit quantization on a MacBook.
This was hacked in an evening - I have no idea if it works correctly.
So far, I've tested just the 7B model.
Here is a typical run:
Here is a typical run using LLaMA-7B:
```java
make -j && ./main -m ../LLaMA-4bit/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -t 8 -n 512
I llama.cpp build info:
make -j && ./main -m ./models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -t 8 -n 512
I llama.cpp build info:
I UNAME_S: Darwin
I UNAME_P: arm
I UNAME_M: arm64
@ -34,7 +33,7 @@ I CXX: Apple clang version 14.0.0 (clang-1400.0.29.202)
make: Nothing to be done for `default'.
main: seed = 1678486056
llama_model_load: loading model from '../LLaMA-4bit/7B/ggml-model-q4_0.bin' - please wait ...
llama_model_load: loading model from './models/7B/ggml-model-q4_0.bin' - please wait ...
llama_model_load: n_vocab = 32000
llama_model_load: n_ctx = 512
llama_model_load: n_embd = 4096
@ -110,6 +109,8 @@ https://user-images.githubusercontent.com/1991296/224442907-7693d4be-acaa-4e01-8
## Usage
Here are the step for the LLaMA-7B model:
```bash
# build this repo
git clone https://github.com/ggerganov/llama.cpp
@ -133,9 +134,40 @@ python3 convert-pth-to-ggml.py models/7B/ 1
./main -m ./models/7B/ggml-model-q4_0.bin -t 8 -n 128
```
For the bigger models, there are a few extra quantization steps. For example, for LLaMA-13B, converting to FP16 format
will create 2 ggml files, instead of one:
```bash
ggml-model-f16.bin
ggml-model-f16.bin.1
```
You need to quantize each of them separately like this:
```bash
./quantize ./models/13B/ggml-model-f16.bin ./models/13B/ggml-model-q4_0.bin 2
./quantize ./models/13B/ggml-model-f16.bin.1 ./models/13B/ggml-model-q4_0.bin.1 2
```
Everything else is the same. Simply run:
```bash
./main -m ./models/13B/ggml-model-q4_0.bin -t 8 -n 128
```
The number of files generated for each model is as follows:
```
7B -> 1 file
13B -> 2 files
33B -> 4 files
65B -> 8 files
```
When running the larger models, make sure you have enough disk space to store all the intermediate files.
## Limitations
- Currently, only LLaMA-7B is supported since I haven't figured out how to merge the tensors of the bigger models. However, in theory, you should be able to run 65B on a 64GB MacBook
- Not sure if my tokenizer is correct. There are a few places where we might have a mistake:
- https://github.com/ggerganov/llama.cpp/blob/26c084662903ddaca19bef982831bfb0856e8257/convert-pth-to-ggml.py#L79-L87
- https://github.com/ggerganov/llama.cpp/blob/26c084662903ddaca19bef982831bfb0856e8257/utils.h#L65-L69

@ -33,12 +33,23 @@ if len(sys.argv) < 3:
# output in the same directory as the model
dir_model = sys.argv[1]
fname_out = sys.argv[1] + "/ggml-model.bin"
fname_hparams = sys.argv[1] + "/params.json"
fname_model = sys.argv[1] + "/consolidated.00.pth"
fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
def get_n_parts(dim):
if dim == 4096:
return 1
elif dim == 5120:
return 2
elif dim == 6656:
return 4
elif dim == 8192:
return 8
else:
print("Invalid dim: " + str(dim))
sys.exit(1)
# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
@ -61,76 +72,91 @@ tokenizer = SentencePieceProcessor(fname_tokenizer)
hparams.update({"vocab_size": tokenizer.vocab_size()})
n_parts = get_n_parts(hparams["dim"])
print(hparams)
print('n_parts = ', n_parts)
model = torch.load(fname_model, map_location="cpu")
fout = open(fname_out, "wb")
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["dim"]))
fout.write(struct.pack("i", hparams["multiple_of"]))
fout.write(struct.pack("i", hparams["n_heads"]))
fout.write(struct.pack("i", hparams["n_layers"]))
fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
fout.write(struct.pack("i", ftype))
# Is this correct??
for i in range(32000):
# TODO: this is probably wrong - not sure how this tokenizer works
text = tokenizer.decode([29889, i]).encode('utf-8')
# remove the first byte (it's always '.')
text = text[1:]
fout.write(struct.pack("i", len(text)))
fout.write(text)
for k, v in model.items():
name = k
shape = v.shape
# skip layers.X.attention.inner_attention.rope.freqs
if name[-5:] == "freqs":
continue
print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
#data = tf.train.load_variable(dir_model, name).squeeze()
data = v.numpy().squeeze()
n_dims = len(data.shape);
# for efficiency - transpose some matrices
# "model/h.*/attn/c_attn/w"
# "model/h.*/attn/c_proj/w"
# "model/h.*/mlp/c_fc/w"
# "model/h.*/mlp/c_proj/w"
#if name[-14:] == "/attn/c_attn/w" or \
# name[-14:] == "/attn/c_proj/w" or \
# name[-11:] == "/mlp/c_fc/w" or \
# name[-13:] == "/mlp/c_proj/w":
# print(" Transposing")
# data = data.transpose()
dshape = data.shape
# default type is fp16
ftype_cur = 1
if ftype == 0 or n_dims == 1:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
# header
str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
fout.write(str);
# data
data.tofile(fout)
fout.close()
print("Done. Output file: " + fname_out)
print("")
for p in range(n_parts):
print('Processing part ', p)
#fname_model = sys.argv[1] + "/consolidated.00.pth"
fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
if (p > 0):
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
model = torch.load(fname_model, map_location="cpu")
fout = open(fname_out, "wb")
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["dim"]))
fout.write(struct.pack("i", hparams["multiple_of"]))
fout.write(struct.pack("i", hparams["n_heads"]))
fout.write(struct.pack("i", hparams["n_layers"]))
fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
fout.write(struct.pack("i", ftype))
# Is this correct??
for i in range(32000):
# TODO: this is probably wrong - not sure how this tokenizer works
text = tokenizer.decode([29889, i]).encode('utf-8')
# remove the first byte (it's always '.')
text = text[1:]
fout.write(struct.pack("i", len(text)))
fout.write(text)
for k, v in model.items():
name = k
shape = v.shape
# skip layers.X.attention.inner_attention.rope.freqs
if name[-5:] == "freqs":
continue
print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
#data = tf.train.load_variable(dir_model, name).squeeze()
data = v.numpy().squeeze()
n_dims = len(data.shape);
# for efficiency - transpose some matrices
# "model/h.*/attn/c_attn/w"
# "model/h.*/attn/c_proj/w"
# "model/h.*/mlp/c_fc/w"
# "model/h.*/mlp/c_proj/w"
#if name[-14:] == "/attn/c_attn/w" or \
# name[-14:] == "/attn/c_proj/w" or \
# name[-11:] == "/mlp/c_fc/w" or \
# name[-13:] == "/mlp/c_proj/w":
# print(" Transposing")
# data = data.transpose()
dshape = data.shape
# default type is fp16
ftype_cur = 1
if ftype == 0 or n_dims == 1:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
# header
sname = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
fout.write(sname);
# data
data.tofile(fout)
# I hope this deallocates the memory ..
model = None
fout.close()
print("Done. Output file: " + fname_out + ", (part ", p, ")")
print("")

@ -366,9 +366,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
float * restrict pd = (float *) (y);
uint8_t * restrict pb = (uint8_t *) (pd + nb);
uint8_t * restrict pd = (uint8_t *) (y + 0*bs);
uint8_t * restrict pb = (uint8_t *) (y + 0*bs + sizeof(float));
uint8_t pp[QK/2];
@ -395,7 +396,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0;
pd[i] = d;
*(float *)pd = d;
pd += bs;
for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
@ -406,7 +408,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
}
memcpy(pb + i*16, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
#else
#error "not implemented for QK"
@ -434,7 +437,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0;
pd[i] = d;
*(float *)pd = d;
pd += bs;
for (int l = 0; l < 8; l++) {
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
@ -445,7 +449,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
}
memcpy(pb + i*16, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
#else
#error "not implemented for QK"
@ -463,7 +468,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;
pd[i] = d;
*(float *)pd = d;
pd += bs;
for (int l = 0; l < QK; l += 2) {
const float v0 = x[i*QK + l + 0]*id;
@ -478,7 +484,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp[l/2] = vi0 | (vi1 << 4);
}
memcpy(pb + i*QK/2, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
#endif
}
@ -535,15 +542,16 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs);
const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float));
// scalar
for (int i = 0; i < nb; i++) {
const float d = pd[i];
const float d = *(const float *) (pd + i*bs);
const uint8_t * restrict pp = pb + i*QK/2;
const uint8_t * restrict pp = pb + i*bs;
for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2];
@ -554,6 +562,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
//printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
y[i*QK + l + 0] = v0;
y[i*QK + l + 1] = v1;
@ -1179,11 +1189,13 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
assert(n % QK == 0);
assert(nb % 2 == 0);
const float * restrict pd0 = (const float *) x;
const float * restrict pd1 = (const float *) y;
const size_t bs = sizeof(float) + QK/2;
const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
const uint8_t * restrict pd0 = (const uint8_t *) (x + 0*bs);
const uint8_t * restrict pd1 = (const uint8_t *) (y + 0*bs);
const uint8_t * restrict pb0 = (const uint8_t *) (x + 0*bs + sizeof(float));
const uint8_t * restrict pb1 = (const uint8_t *) (y + 0*bs + sizeof(float));
float sumf = 0.0;
@ -1193,23 +1205,23 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const float d0_0 = pd0[i + 0];
const float d1_0 = pd1[i + 0];
const float d0_1 = pd0[i + 1];
const float d1_1 = pd1[i + 1];
const float d0_0 = *(const float *) (pd0 + i*bs);
const float d1_0 = *(const float *) (pd1 + i*bs);
const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
//printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(p0);
const uint8x16_t v1_0 = vld1q_u8(p1);
const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
const uint8x16_t v0_1 = vld1q_u8(p0 + bs);
const uint8x16_t v1_1 = vld1q_u8(p1 + bs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
@ -1280,21 +1292,21 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const float d0_0 = pd0[i + 0];
const float d0_1 = pd0[i + 1];
const float d1_0 = pd1[i + 0];
const float d1_1 = pd1[i + 1];
const float d0_0 = *(const float *) (pd0 + i*bs);
const float d1_0 = *(const float *) (pd1 + i*bs);
const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
const v128_t m4b = wasm_u8x16_splat(0xf);
const v128_t s8b = wasm_i8x16_splat(0x8);
const v128_t v0_0 = wasm_v128_load(p0);
const v128_t v0_1 = wasm_v128_load(p0 + 16);
const v128_t v0_1 = wasm_v128_load(p0 + bs);
const v128_t v1_0 = wasm_v128_load(p1);
const v128_t v1_1 = wasm_v128_load(p1 + 16);
const v128_t v1_1 = wasm_v128_load(p1 + bs);
// 4-bit -> 8-bit
const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@ -1363,11 +1375,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = pd0[i];
const float d1 = pd1[i];
const float d0 = *(const float *) (pd0 + i*bs);
const float d1 = *(const float *) (pd1 + i*bs);
const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;
for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j];
@ -1552,16 +1564,17 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
assert(n % QK == 0);
const int nb = n / QK;
const size_t bs = sizeof(float) + QK/2;
const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs);
const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float));
#if __ARM_NEON
#if QK == 32
for (int i = 0; i < nb; ++i) {
const float d0 = pd[i]*v;
const float d0 = v*(*(const float *) (pd + i*bs));
const uint8_t * restrict pp = pb + i*16;
const uint8_t * restrict pp = pb + i*bs;
const uint8x8_t m4b = vdup_n_u8(0xf);
const int8x8_t s8b = vdup_n_s8(0x8);
@ -1615,9 +1628,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d = pd[i];
const float d = *(const float *) (pd + i*bs);
const uint8_t * restrict pp = pb + i*QK/2;
const uint8_t * restrict pp = pb + i*bs;
for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2];

@ -11,6 +11,14 @@
#include <string>
#include <vector>
// determine number of model parts based on the dimension
static const std::map<int, int> LLAMA_N_PARTS = {
{ 4096, 1 },
{ 5120, 2 },
{ 6656, 4 },
{ 8192, 8 },
};
// default hparams (LLaMA 7B)
struct llama_hparams {
int32_t n_vocab = 32000;
@ -82,6 +90,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
}
int n_ff = 0;
int n_parts = 0;
// load hparams
{
@ -99,6 +108,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
hparams.n_ctx = n_ctx;
n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
@ -109,6 +119,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
printf("%s: f16 = %d\n", __func__, hparams.f16);
printf("%s: n_ff = %d\n", __func__, n_ff);
printf("%s: n_parts = %d\n", __func__, n_parts);
}
// load vocab
@ -220,7 +231,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
model.layers.resize(n_layer);
model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.output = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
@ -234,14 +245,14 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
layer.w2 = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);
@ -282,94 +293,208 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
int n_tensors = 0;
size_t total_size = 0;
const size_t file_offset = fin.tellg();
printf("%s: ", __func__);
fin.close();
while (true) {
int32_t n_dims;
int32_t length;
int32_t ftype;
std::vector<uint8_t> tmp;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
for (int i = 0; i < n_parts; ++i) {
const int part_id = i;
//const int part_id = n_parts - i - 1;
if (fin.eof()) {
break;
}
std::string fname_part = fname;
if (i > 0) {
fname_part += "." + std::to_string(i);
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
printf("%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
std::string name(length, 0);
fin.read(&name[0], length);
fin = std::ifstream(fname_part, std::ios::binary);
fin.seekg(file_offset);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
// load weights
{
int n_tensors = 0;
size_t total_size = 0;
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
printf("%s: ", __func__);
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false;
}
while (true) {
int32_t n_dims;
int32_t length;
int32_t ftype;
if (0) {
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
size_t bpe = 0;
std::string name(length, 0);
fin.read(&name[0], length);
switch (ftype) {
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
default:
{
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
// split_type = 0: split by columns
// split_type = 1: split by rows
int split_type = 0;
// split_type = 0:
// regex:
// - tok_embeddings.*
// - layers.*.attention.wo.weight
// - layers.*.feed_forward.w2.weight
// split_type = 1:
// regex:
// - output.*
// - layers.*.attention.wq.weight
// - layers.*.attention.wk.weight
// - layers.*.attention.wv.weight
// - layers.*.feed_forward.w1.weight
// - layers.*.feed_forward.w3.weight
if (name.find("tok_embeddings") != std::string::npos) {
split_type = 0;
} else if (name.find("layers") != std::string::npos) {
if (name.find("attention.wo.weight") != std::string::npos) {
split_type = 0;
} else if (name.find("feed_forward.w2.weight") != std::string::npos) {
split_type = 0;
} else {
split_type = 1;
}
} else if (name.find("output") != std::string::npos) {
split_type = 1;
}
auto tensor = model.tensors[name.data()];
if (n_dims == 1) {
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
} else {
if (ggml_nelements(tensor)/n_parts != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
}
if (n_dims == 1) {
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
return false;
}
} else {
if (split_type == 0) {
if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
return false;
}
} else {
if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
return false;
}
};
}
}
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
if (0) {
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
printf("%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
}
size_t bpe = 0;
switch (ftype) {
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
default:
{
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
return false;
}
};
if (n_dims == 1 || n_parts == 1) {
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
if (part_id == 0) {
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
} else {
fin.seekg(ggml_nbytes(tensor), std::ios::cur);
}
total_size += ggml_nbytes(tensor);
} else {
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
return false;
}
if (split_type == 0) {
const int np0 = ne[0];
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
assert(row_size == tensor->nb[1]);
for (int i1 = 0; i1 < ne[1]; ++i1) {
const size_t offset_row = i1*row_size;
const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
}
} else {
const int np1 = ne[1];
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
//printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
if (++n_tensors % 8 == 0) {
printf(".");
fflush(stdout);
for (int i1 = 0; i1 < ne[1]; ++i1) {
const size_t offset_row = (i1 + part_id*np1)*row_size;
fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
}
}
total_size += ggml_nbytes(tensor)/n_parts;
}
//printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
if (++n_tensors % 8 == 0) {
printf(".");
fflush(stdout);
}
}
}
printf(" done\n");
printf(" done\n");
printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
}
printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
}
fin.close();
fin.close();
}
return true;
}

@ -448,7 +448,8 @@ gpt_vocab::id llama_sample_top_p(
size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
const int nb = k / qk;
const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*qk/2);
const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2);
const size_t row_size = nb*bs;
assert(k % qk == 0);
@ -457,8 +458,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
char * pdst = (char *) dst;
for (int j = 0; j < n; j += k) {
float * pd = (float *) (pdst + (j/k)*row_size);
uint8_t * pb = (uint8_t *) (pd + nb);
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
@ -472,7 +473,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;
pd[i] = d;
*(float *) pd = d;
pd += bs;
for (int l = 0; l < qk; l += 2) {
const float v0 = (src[j + i*qk + l + 0])*id;
@ -490,7 +492,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
pp[l/2] = vi0 | (vi1 << 4);
}
memcpy(pb + i*qk/2, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
}
}

Loading…
Cancel
Save