You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ggml/tests/test3.c

96 lines
2.7 KiB

#include "ggml/ggml.h"

#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
// Tolerance comparison used by the result check.
// Returns true iff |a - b| is strictly smaller than epsilon; if either
// operand (or epsilon) is NaN the comparison is false.
bool is_close(float a, float b, float epsilon) {
    const double gap = fabs(a - b);
    return gap < epsilon;
}
// Fits a linear model x (NF weights) to NP synthetic labelled points by
// minimising the L2-regularised mean squared error
//
//     f(x) = sum_j (F_j . x - l_j)^2 / NP + lambda * sum_i x_i^2
//
// with ggml's optimizer (L-BFGS), starting from x = 0, and then checks that
// the planted solution is recovered: x_i ~ +1 for i < NF/2 and x_i ~ -1
// otherwise (absolute tolerance 1e-2).
//
// Usage: test3 [n_threads]   (n_threads defaults to 8; must be a positive integer)
//
// Returns EXIT_SUCCESS when the optimizer reports GGML_OPT_OK and every weight
// is within tolerance, EXIT_FAILURE otherwise. Failures are checked and
// reported explicitly rather than through assert(), so the test still
// verifies its result when compiled with NDEBUG.
int main(int argc, const char ** argv) {
    struct ggml_init_params params = {
        .mem_size   = 1024*1024*1024,
        .mem_buffer = NULL,
    };

    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_LBFGS);
    //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);

    // Optional argv[1]: optimizer thread count (default 8). strtol is used
    // instead of atoi so that non-numeric input, trailing characters and
    // non-positive or out-of-range values are rejected instead of silently
    // becoming a bogus thread count.
    opt_params.n_threads = 8;
    if (argc > 1) {
        char * end = NULL;
        const long n = strtol(argv[1], &end, 10);
        if (end == argv[1] || *end != '\0' || n <= 0 || n > INT_MAX) {
            fprintf(stderr, "%s: invalid thread count '%s'\n", __func__, argv[1]);
            return EXIT_FAILURE;
        }
        opt_params.n_threads = (int) n;
    }

    const int NP = 1 << 12; // number of data points
    const int NF = 1 << 8;  // number of features (weights to learn)

    struct ggml_context * ctx0 = ggml_init(params);
    if (ctx0 == NULL) {
        fprintf(stderr, "%s: ggml_init failed\n", __func__);
        return EXIT_FAILURE;
    }

    // F: NP rows of NF features (row j starts at element j*NF); l: one label per row.
    struct ggml_tensor * F = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, NF, NP);
    struct ggml_tensor * l = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NP);

    // regularization weight
    struct ggml_tensor * lambda = ggml_new_f32(ctx0, 1e-5f);

    // Synthetic data: the first NP/2 points are labelled +1 and are "on" in
    // the first NF/2 features; the rest are labelled -1 and are "on" in the
    // last NF/2 features. Every entry gets uniform noise in [-0.05, 0.05] and
    // is scaled by 1/(NF/2), so x = (+1,...,+1,-1,...,-1) reproduces the
    // labels up to noise. The fixed seed keeps the run reproducible.
    srand(0);
    float * Fd = (float *) F->data;
    float * ld = (float *) l->data;
    for (int j = 0; j < NP; j++) {
        const float ll = j < NP/2 ? 1.0f : -1.0f;
        ld[j] = ll;
        for (int i = 0; i < NF; i++) {
            const int   on    = (ll > 0 && i < NF/2) || (ll < 0 && i >= NF/2);
            const float noise = ((float)rand()/(float)RAND_MAX - 0.5f)*0.1f;
            Fd[j*NF + i] = ((on ? 1.0f : 0.0f) + noise)/(0.5f*NF);
        }
    }

    // Parameter to optimise; initial guess x = 0.
    struct ggml_tensor * x = ggml_set_f32(ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NF), 0.0f);
    ggml_set_param(ctx0, x);

    // f = sum[(fj*x - l)^2]/n + lambda*|x^2|
    struct ggml_tensor * f =
        ggml_add(ctx0,
                ggml_div(ctx0,
                    ggml_sum(ctx0,
                        ggml_sqr(ctx0,
                            ggml_sub(ctx0,
                                ggml_mul_mat(ctx0, F, x),
                                l)
                            )
                        ),
                    ggml_new_f32(ctx0, (float) NP)
                    ),
                ggml_mul(ctx0,
                    ggml_sum(ctx0, ggml_sqr(ctx0, x)),
                    lambda)
                );

    // NOTE(review): ggml_opt is given a NULL context here; presumably it then
    // allocates its own working context — confirm against ggml_opt.
    const enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
    if (res != GGML_OPT_OK) {
        fprintf(stderr, "%s: ggml_opt failed (result = %d)\n", __func__, (int) res);
        ggml_free(ctx0);
        return EXIT_FAILURE;
    }

    const float * xd = (const float *) x->data;

    // print results
    for (int i = 0; i < 16; i++) {
        printf("x[%3d] = %g\n", i, xd[i]);
    }
    printf("...\n");
    for (int i = NF - 16; i < NF; i++) {
        printf("x[%3d] = %g\n", i, xd[i]);
    }
    printf("\n");

    // Check every weight and report all mismatches, not just the first one.
    int n_bad = 0;
    for (int i = 0; i < NF; ++i) {
        const float expected = i < NF/2 ? 1.0f : -1.0f;
        if (!is_close(xd[i], expected, 1e-2f)) {
            fprintf(stderr, "x[%3d] = %g, expected %g\n", i, xd[i], expected);
            n_bad++;
        }
    }

    ggml_free(ctx0);

    if (n_bad > 0) {
        fprintf(stderr, "%s: %d of %d weights are not within 1e-2 of the expected value\n",
                __func__, n_bad, NF);
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}