ggml : initial tests with libnvblas

pull/239/head
Georgi Gerganov 2 years ago
parent 3996ecc156
commit 683f111088
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -96,6 +96,8 @@ typedef void* thread_ret_t;
#include <Accelerate/Accelerate.h>
#elif GGML_USE_OPENBLAS
#include <cblas.h>
// sgemm
extern void sgemm_(char* transa, char* transb, int* m, int* n, int* k, float* alpha, float* a, int* lda, float* b, int* ldb, float* beta, float* c, int* ldc);
#endif
// floating point type used to accumulate sums
@ -4588,11 +4590,23 @@ void ggml_compute_forward_mul_mat_f16_f32(
// zT = y * xT
{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
x, ne10,
0.0f, d, ne01);
//cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
// ne11, ne01, ne10,
// 1.0f, y, ne10,
// x, ne10,
// 0.0f, d, ne01);
// this is compatible with nvblas
float one = 1.0f;
float zero = 0.0f;
sgemm_(
"T", "N",
&ne0, &ne1, &ne10,
&one,
x, &ne10,
y, &ne10,
&zero,
d, &ne0);
}
}
}

Loading…
Cancel
Save