diff --git a/ggml.c b/ggml.c index f379f55..0b91a19 100644 --- a/ggml.c +++ b/ggml.c @@ -100,6 +100,8 @@ typedef void* thread_ret_t; #include #elif GGML_USE_OPENBLAS #include +// sgemm +extern void sgemm_(char* transa, char* transb, int* m, int* n, int* k, float* alpha, float* a, int* lda, float* b, int* ldb, float* beta, float* c, int* ldc); #endif // floating point type used to accumulate sums @@ -4592,11 +4594,23 @@ void ggml_compute_forward_mul_mat_f16_f32( // zT = y * xT { - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); + //cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne11, ne01, ne10, + // 1.0f, y, ne10, + // x, ne10, + // 0.0f, d, ne01); + + // this is compatible with nvblas + float one = 1.0f; + float zero = 0.0f; + sgemm_( + "T", "N", + &ne0, &ne1, &ne10, + &one, + x, &ne10, + y, &ne10, + &zero, + d, &ne0); } } }