diff --git a/ggml.c b/ggml.c index 00cfc26..d67612c 100644 --- a/ggml.c +++ b/ggml.c @@ -1286,6 +1286,7 @@ struct ggml_context { struct ggml_object * objects_end; struct ggml_scratch scratch; + struct ggml_scratch scratch_save; }; struct ggml_context_container { @@ -1551,6 +1552,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { /*.objects_begin =*/ NULL, /*.objects_end =*/ NULL, /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, }; ggml_assert_aligned(ctx->mem_buffer); @@ -1770,13 +1772,12 @@ struct ggml_tensor * ggml_new_tensor_4d( } struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { - // TODO: FIX ME !!!!!!!!!!!!!!!!!!!!!!!! - void * tmp = ctx->scratch.data; + ctx->scratch_save = ctx->scratch; ctx->scratch.data = NULL; struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - ctx->scratch.data = tmp; + ctx->scratch = ctx->scratch_save; ggml_set_i32(result, value); @@ -1784,13 +1785,12 @@ struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { } struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { - // TODO: FIX ME !!!!!!!!!!!!!!!!!!!!!!!! - void * tmp = ctx->scratch.data; + ctx->scratch_save = ctx->scratch; ctx->scratch.data = NULL; struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); - ctx->scratch.data = tmp; + ctx->scratch = ctx->scratch_save; ggml_set_f32(result, value);