From eaa4006047cae639d1af0f0ab93e78d552aaa100 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 26 Feb 2023 15:20:18 +0200
Subject: [PATCH] gpt : fix memory usage computation

---
 examples/gpt-2/main.cpp | 36 ++++++++++++++++++------------------
 examples/gpt-j/main.cpp | 34 +++++++++++++++++-----------------
 include/ggml/ggml.h     |  6 ++++--
 src/ggml.c              |  4 ++++
 4 files changed, 43 insertions(+), 37 deletions(-)

diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp
index 311ffee..f1d3c6f 100644
--- a/examples/gpt-2/main.cpp
+++ b/examples/gpt-2/main.cpp
@@ -158,32 +158,32 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         const int n_ctx   = hparams.n_ctx;
         const int n_vocab = hparams.n_vocab;
 
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
 
-        ctx_size += n_vocab*n_embd*ggml_type_size(wtype); // wte
-        ctx_size += n_ctx*n_embd*ggml_type_size(GGML_TYPE_F32); // wpe
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
+        ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
 
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
 
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_g
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_2_b
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
 
-        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_size(wtype)); // c_attn_attn_w
-        ctx_size += n_layer*( 3*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_attn_b
+        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
+        ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
 
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
-        ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_attn_proj_b
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
+        ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
-        ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
+        ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w
-        ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
+        ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
 
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F16); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F16); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
 
         ctx_size += (6 + 12*n_layer)*256; // object overhead
 
diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp
index 544ec3a..d99b9a1 100644
--- a/examples/gpt-j/main.cpp
+++ b/examples/gpt-j/main.cpp
@@ -160,31 +160,31 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
         const int n_ctx   = hparams.n_ctx;
         const int n_vocab = hparams.n_vocab;
 
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_g
-        ctx_size += n_embd*ggml_type_size(GGML_TYPE_F32); // ln_f_b
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
 
-        ctx_size += n_embd*n_vocab*ggml_type_size(wtype); // wte
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte
 
-        ctx_size += n_embd*n_vocab*ggml_type_size(wtype); // lmh_g
-        ctx_size += n_vocab*ggml_type_size(GGML_TYPE_F32); // lmh_b
+        ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // lmh_g
+        ctx_size += n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b
 
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_g
-        ctx_size += n_layer*(n_embd*ggml_type_size(GGML_TYPE_F32)); // ln_1_b
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
 
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_q_proj_w
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_k_proj_w
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_v_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_q_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_k_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_v_proj_w
 
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_size(wtype)); // c_attn_proj_w
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_fc_w
-        ctx_size += n_layer*( 4*n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_fc_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
+        ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
 
-        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_size(wtype)); // c_mlp_proj_w_trans
-        ctx_size += n_layer*( n_embd*ggml_type_size(GGML_TYPE_F32)); // c_mlp_proj_b
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w_trans
+        ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
 
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F16); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_size(GGML_TYPE_F16); // memory_v
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
 
         ctx_size += (5 + 10*n_layer)*256; // object overhead
 
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index 1b8093d..cdd8553 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -328,8 +328,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
 int    ggml_nelements(const struct ggml_tensor * tensor);
 size_t ggml_nbytes   (const struct ggml_tensor * tensor);
 
-int    ggml_blck_size(enum ggml_type type);
-size_t ggml_type_size(enum ggml_type type);
+int    ggml_blck_size (enum ggml_type type);
+size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
+float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+
 size_t ggml_element_size(const struct ggml_tensor * tensor);
 
 struct ggml_context * ggml_init(struct ggml_init_params params);
diff --git a/src/ggml.c b/src/ggml.c
index 5d5a9c9..1fc810f 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -1940,6 +1940,10 @@ size_t ggml_type_size(enum ggml_type type) {
     return GGML_TYPE_SIZE[type];
 }
 
+float ggml_type_sizef(enum ggml_type type) {
+    return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
 }
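---

A note on the semantics, not part of the patch itself: ggml_type_size() returns
the size in bytes of a whole *block* of elements, while the new ggml_type_sizef()
returns the per-element size as a float (bytes per block divided by elements per
block). For GGML_TYPE_F32 and GGML_TYPE_F16 the block size is 1 and the two agree,
but for a block-packed type the per-element size is fractional and the integer
variant overcounts by a factor of the block size. A minimal standalone C sketch
of the difference, using hypothetical numbers (20-byte blocks of 32 elements;
not values taken from ggml):

    #include <stdio.h>

    int main(void) {
        const size_t type_size = 20;   /* hypothetical: bytes per block    */
        const int    blck_size = 32;   /* hypothetical: elements per block */
        const int    n         = 4096; /* elements to account for          */

        /* old-style accounting: treats the block size as the element size */
        size_t wrong = n*type_size;                      /* 81920 bytes */

        /* new-style accounting, mirroring ggml_type_sizef() */
        size_t right = n*(((float)type_size)/blck_size); /*  2560 bytes */

        printf("wrong = %zu, right = %zu\n", wrong, right);
        return 0;
    }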