ggml : add "scratch" buffer support

pull/431/head
Georgi Gerganov 2 years ago
parent 291980369c
commit 60eff46b0a

ggml.c

@@ -1284,6 +1284,8 @@ struct ggml_context {
 
     struct ggml_object * objects_begin;
     struct ggml_object * objects_end;
+
+    struct ggml_scratch scratch;
 };
 
 struct ggml_context_container {
@@ -1542,12 +1544,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     }
 
     *ctx = (struct ggml_context) {
-        .mem_size         = params.mem_size,
-        .mem_buffer       = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
-        .mem_buffer_owned = params.mem_buffer ? false : true,
-        .n_objects        = 0,
-        .objects_begin    = NULL,
-        .objects_end      = NULL,
+        /*.mem_size         =*/ params.mem_size,
+        /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
+        /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
+        /*.n_objects        =*/ 0,
+        /*.objects_begin    =*/ NULL,
+        /*.objects_end      =*/ NULL,
+        /*.scratch          =*/ { 0, 0, NULL, 0, NULL },
     };
 
     ggml_assert_aligned(ctx->mem_buffer);
@@ -1592,6 +1595,10 @@ size_t ggml_used_mem(const struct ggml_context * ctx) {
     return ctx->objects_end->offset + ctx->objects_end->size;
 }
 
+void ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
+    ctx->scratch = scratch;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_tensor * ggml_new_tensor_impl(
@@ -1620,21 +1627,42 @@ struct ggml_tensor * ggml_new_tensor_impl(
     }
 
     size_needed += sizeof(struct ggml_tensor);
 
-    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
-        GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
-        assert(false);
-        return NULL;
-    }
-
     char * const mem_buffer = ctx->mem_buffer;
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
 
-    *obj_new = (struct ggml_object) {
-        .offset = cur_end + GGML_OBJECT_SIZE,
-        .size   = size_needed,
-        .next   = NULL,
-    };
+    size_t scratch_size = ctx->scratch.k & 1 ? ctx->scratch.size1 : ctx->scratch.size0;
+    void * scratch_data = ctx->scratch.k & 1 ? ctx->scratch.data1 : ctx->scratch.data0;
+
+    if (scratch_data == NULL) {
+        if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                    __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
+            assert(false);
+            return NULL;
+        }
+
+        *obj_new = (struct ggml_object) {
+            .offset = cur_end + GGML_OBJECT_SIZE,
+            .size   = size_needed,
+            .next   = NULL,
+        };
+    } else {
+        if (size_needed > scratch_size) {
+            GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
+            assert(false);
+            return NULL;
+        }
+
+        data = scratch_data;
+
+        *obj_new = (struct ggml_object) {
+            .offset = cur_end + GGML_OBJECT_SIZE,
+            .size   = sizeof(struct ggml_tensor),
+            .next   = NULL,
+        };
+        ctx->scratch.k++;
+    }
 
     if (obj_cur != NULL) {
         obj_cur->next = obj_new;
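
A minimal sketch (not part of the patch) of what the new allocation path in ggml_new_tensor_impl means for a caller. The buffer size and tensor shapes are made up; the behavior follows the code above: tensor headers (struct ggml_tensor) always come from the context pool, while tensor data alternates between the two scratch regions according to the parity of k.

    #include "ggml.h"
    #include <vector>

    int main() {
        // the context pool holds tensor headers, and tensor data while scratch is off
        struct ggml_context * ctx0 = ggml_init({ /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr });

        // one caller-owned allocation, split into the two scratch halves
        std::vector<uint8_t> buf(2u*1024*1024);
        ggml_set_scratch(ctx0, { 0, buf.size()/2, buf.data(), buf.size()/2, buf.data() + buf.size()/2 });

        // data ping-pongs between the halves: k & 1 picks the region, then k is incremented
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1024); // k = 0 -> data0
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1024); // k = 1 -> data1
        struct ggml_tensor * c = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1024); // k = 2 -> data0 again, overwriting a's data
        (void) a; (void) b; (void) c; // silence unused warnings in this sketch

        // NULL data pointers disable scratch: allocations go back to the pool
        ggml_set_scratch(ctx0, { 0, 0, nullptr, 0, nullptr });

        ggml_free(ctx0);
        return 0;
    }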

ggml.h

@@ -301,6 +301,17 @@ struct ggml_cgraph {
     int64_t perf_time_us;
 };
 
+// scratch buffer
+struct ggml_scratch {
+    int    k;
+
+    size_t size0;
+    void * data0;
+
+    size_t size1;
+    void * data1;
+};
+
 struct ggml_init_params {
     // memory pool
     size_t mem_size;   // bytes
@@ -327,6 +338,8 @@ void ggml_free(struct ggml_context * ctx);
 
 size_t ggml_used_mem(const struct ggml_context * ctx);
 
+void ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
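
Both ggml.c and whisper.cpp initialize ggml_scratch positionally, so the field order above is effectively the API contract. A short sketch of the two states (ctx, data, and size are assumed to exist):

    // enabled: positional order is { k, size0, data0, size1, data1 }
    ggml_set_scratch(ctx, { 0, size/2, data, size/2, (uint8_t *) data + size/2 });

    // disabled: null data pointers make allocations fall back to the context
    // memory pool; a freshly created context starts in this state
    ggml_set_scratch(ctx, { 0, 0, nullptr, 0, nullptr });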

whisper.cpp

@@ -217,6 +217,14 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
 
 static const size_t MB = 1024*1024;
 
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH = {
+    { MODEL_TINY,    32ull*MB },
+    { MODEL_BASE,    44ull*MB },
+    { MODEL_SMALL,   64ull*MB },
+    { MODEL_MEDIUM,  84ull*MB },
+    { MODEL_LARGE,  110ull*MB },
+};
+
 static const std::map<e_model, size_t> MEM_REQ_MODEL = {
     { MODEL_TINY,     74ull*MB },
     { MODEL_BASE,    142ull*MB },
@@ -556,6 +564,7 @@ struct whisper_context {
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
     // memory buffers used by encode / decode contexts
+    std::vector<uint8_t> buf_scratch;
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_compute_layer;
 
@@ -744,8 +753,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
     {
         // this is the total memory required to run the inference
         const size_t mem_required =
-                scale*MEM_REQ_MODEL.at   (model.type) +
-                scale*MEM_REQ_KV_CROSS.at(model.type) +
+                      MEM_REQ_SCRATCH.at (model.type) +
+                scale*MEM_REQ_MODEL.at   (model.type) +
+                scale*MEM_REQ_KV_CROSS.at(model.type) +
                 scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) +
                 scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type));
 
@@ -783,6 +793,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
     }
 
+    wctx.buf_scratch.resize      (MEM_REQ_SCRATCH.at(model.type));
     wctx.buf_compute.resize      (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
     wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
 }
@@ -1335,6 +1346,8 @@ static bool whisper_encode(
 
     struct ggml_tensor * cur;
 
+    ggml_set_scratch(ctx0, { 0, wctx.buf_scratch.size()/2, wctx.buf_scratch.data(), wctx.buf_scratch.size()/2, wctx.buf_scratch.data() + wctx.buf_scratch.size()/2 });
+
     // convolution + gelu
     {
         cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
@@ -1344,6 +1357,7 @@ static bool whisper_encode(
                     cur),
                 cur);
 
         cur = ggml_gelu(ctx0, cur);
 
         cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
@@ -1356,6 +1370,8 @@ static bool whisper_encode(
         cur = ggml_gelu(ctx0, cur);
     }
 
+    ggml_set_scratch(ctx0, { 0, 0, nullptr, 0, nullptr });
+
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;
@@ -1376,6 +1392,7 @@ static bool whisper_encode(
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
     cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
 
     // ===================================================================
     // original:
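
The whisper_encode changes show the intended usage pattern: bracket a section whose intermediates are disposable between a ggml_set_scratch call that installs the two halves and one that clears them. A condensed sketch (not from the patch; ctx0, buf, and the tensors w1, w2, x are hypothetical stand-ins):

    // scratch on: results created in this section live in the two halves
    ggml_set_scratch(ctx0, { 0, buf.size()/2, buf.data(), buf.size()/2, buf.data() + buf.size()/2 });

    struct ggml_tensor * t = ggml_mul_mat(ctx0, w1, x); // data in half 0
    t = ggml_gelu(ctx0, t);                             // data in half 1

    // scratch off before producing results that must outlive the section
    ggml_set_scratch(ctx0, { 0, 0, nullptr, 0, nullptr });

    struct ggml_tensor * out = ggml_mul_mat(ctx0, w2, t); // out lives in the pool and can still
                                                          // read t: t's half stays intact until
                                                          // it is handed to another tensor

Because every scratch allocation starts at the base of its half, only the last two intermediates are safe at any point; whisper_encode therefore switches scratch off right after the convolution block, so the block's output is consumed by the positional-embedding add before either half is reused.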
