ggml : Q4_0 quantization support (ggml_get_rows())

gq
Georgi Gerganov 2 years ago
parent ca2714384b
commit 38faca7efe
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

@ -198,6 +198,8 @@ struct ggml_object;
struct ggml_context;
enum ggml_type {
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@ -326,6 +328,7 @@ void ggml_print_objects(const struct ggml_context * ctx);
int ggml_nelements(const struct ggml_tensor * tensor);
size_t ggml_nbytes (const struct ggml_tensor * tensor);
int ggml_blck_size (enum ggml_type type);
size_t ggml_type_size (enum ggml_type type);
size_t ggml_element_size(const struct ggml_tensor * tensor);

@ -348,6 +348,166 @@ int64_t ggml_cycles_per_ms(void) {
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
//
// quantization
//
// method 5
// blocks of 64 elements
// represented with a single float (delta) and 32 8-bit ints (i.e 64 4-bit signed integer factors)
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert(k % 64 == 0);
const int nb = k / 64;
float * restrict pd = (float *) (y);
uint8_t * restrict pb = (uint8_t *) (pd + nb);
uint8_t pp[32];
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
#if defined(__AVX2__)
{
__m256 srcv [8];
__m256 asrcv[8];
__m256 amaxv[8];
// TODO: unroll these loops (maybe not? this is more compact)
for (int l = 0; l < 8; l++) srcv[l] = _mm256_loadu_ps(x + i*QK + 8*l);
for (int l = 0; l < 8; l++) asrcv[l] = _mm256_and_ps(srcv[l], (__m256) _mm256_set1_epi32(0x7fffffff));
for (int l = 0; l < 4; l++) amaxv[2*l] = _mm256_max_ps(asrcv[2*l], asrcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = _mm256_max_ps(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = _mm256_max_ps(amaxv[8*l], amaxv[8*l+4]);
const __m256 amaxv0_0 = _mm256_permute2f128_ps(amaxv[0], amaxv[0], 3);
const __m256 amaxv0_1 = _mm256_max_ps(amaxv[0], amaxv0_0);
const __m256 amaxv0_2 = _mm256_permute_ps(amaxv0_1, 0x4e);
const __m256 amaxv0_3 = _mm256_max_ps(amaxv0_1, amaxv0_2);
const __m256 amaxv0_4 = _mm256_permute_ps(amaxv0_3, 0xb1);
const __m256 amaxv0_5 = _mm256_max_ps(amaxv0_3, amaxv0_4);
amax = _mm256_cvtss_f32(amaxv0_5);
//printf("amax = %f\n", amax);
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;
pd[i] = d;
const __m256 idv = _mm256_set1_ps(id);
for (int l = 0; l < 8; l++) {
__m256 v = _mm256_mul_ps(srcv[l], idv);
#if 0
v[0] += frand(); v[1] += frand(); v[2] += frand(); v[3] += frand();
v[4] += frand(); v[5] += frand(); v[6] += frand(); v[7] += frand();
#endif
// convert to int8
__m256i vi = _mm256_cvtps_epi32(v);
vi = _mm256_add_epi32(vi, _mm256_set1_epi32(8));
int32_t vi_0 = _mm256_extract_epi32(vi, 0);
int32_t vi_1 = _mm256_extract_epi32(vi, 1);
int32_t vi_2 = _mm256_extract_epi32(vi, 2);
int32_t vi_3 = _mm256_extract_epi32(vi, 3);
int32_t vi_4 = _mm256_extract_epi32(vi, 4);
int32_t vi_5 = _mm256_extract_epi32(vi, 5);
int32_t vi_6 = _mm256_extract_epi32(vi, 6);
int32_t vi_7 = _mm256_extract_epi32(vi, 7);
// convert to 4-bit, 2 consecutive packed into 1 byte
pp[4*l + 0] = vi_0 | (vi_1 << 4);
pp[4*l + 1] = vi_2 | (vi_3 << 4);
pp[4*l + 2] = vi_4 | (vi_5 << 4);
pp[4*l + 3] = vi_6 | (vi_7 << 4);
//printf("vi: %7d %7d %7d %7d %7d %7d %7d %7d\n", vi_0, vi_1, vi_2, vi_3, vi_4, vi_5, vi_6, vi_7);
//printf("v : %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f\n", v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
assert(vi_0 >= 0 && vi_0 < 16);
assert(vi_1 >= 0 && vi_1 < 16);
assert(vi_2 >= 0 && vi_2 < 16);
assert(vi_3 >= 0 && vi_3 < 16);
assert(vi_4 >= 0 && vi_4 < 16);
assert(vi_5 >= 0 && vi_5 < 16);
assert(vi_6 >= 0 && vi_6 < 16);
assert(vi_7 >= 0 && vi_7 < 16);
}
memcpy(pb + i*32, pp, sizeof(pp));
}
#elif defined(__ARM_NEON) && 0
{
// TODO
#pragma warning "implement me !!"
}
#else
{
for (int l = 0; l < 64; l++) {
const float v = x[i*64 + l];
amax = MAX(amax, fabsf(v));
}
const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;
pd[i] = d;
for (int l = 0; l < 64; l += 2) {
const float v0 = x[i*64 + l + 0]*id;
const float v1 = x[i*64 + l + 1]*id;
const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
assert(vi0 >= 0 && vi0 < 16);
assert(vi1 >= 0 && vi1 < 16);
pp[l] = vi0 | (vi1 << 4);
}
memcpy(pb + i*32, pp, sizeof(pp));
}
#endif
//printf("min %f max %f\n", min, max);
}
}
void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
assert(k % 64 == 0);
const int nb = k / 64;
const float * restrict pd = (const float *) (x);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
for (int i = 0; i < nb; i++) {
const float d = pd[i];
const uint8_t * restrict pp = pb + i*32;
for (int l = 0; l < 64; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
y[i*64 + l + 0] = v0;
y[i*64 + l + 1] = v1;
}
}
}
//
// simd mappings
//
@ -1165,7 +1325,21 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
// data types
//
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
64,
64,
1,
1,
1,
1,
1,
};
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
sizeof(float ) + 32,
sizeof(float )*2 + 32,
sizeof(int8_t ),
sizeof(int16_t),
sizeof(int32_t),
@ -1173,6 +1347,9 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
sizeof(float ),
};
// don't forget to update the array above when adding new types
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"NONE",
@ -1213,6 +1390,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"FLASH_FF",
};
static_assert(GGML_OP_COUNT == 33, "GGML_OP_COUNT != 33");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1253,6 +1432,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_ff(x)",
};
static_assert(GGML_OP_COUNT == 33, "GGML_OP_COUNT != 33");
//
// ggml object
//
@ -1380,7 +1561,11 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type];
return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
}
int ggml_blck_size(enum ggml_type type) {
return GGML_BLCK_SIZE[type];
}
size_t ggml_type_size(enum ggml_type type) {
@ -1814,6 +1999,14 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
char * const data = tensor->data;
switch (tensor->type) {
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@ -1866,6 +2059,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
char * const data = tensor->data;
switch (tensor->type) {
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@ -1912,6 +2113,14 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
switch (tensor->type) {
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -1948,6 +2157,14 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
switch (tensor->type) {
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -1982,6 +2199,14 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
switch (tensor->type) {
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -2018,6 +2243,14 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
switch (tensor->type) {
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3435,6 +3668,8 @@ static void ggml_compute_forward_dup(
{
ggml_compute_forward_dup_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3510,6 +3745,8 @@ static void ggml_compute_forward_add(
{
ggml_compute_forward_add_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3560,6 +3797,8 @@ static void ggml_compute_forward_sub(
{
ggml_compute_forward_sub_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3610,6 +3849,8 @@ static void ggml_compute_forward_mul(
{
ggml_compute_forward_mul_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3660,6 +3901,8 @@ static void ggml_compute_forward_div(
{
ggml_compute_forward_div_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3706,6 +3949,8 @@ static void ggml_compute_forward_sqr(
{
ggml_compute_forward_sqr_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3752,6 +3997,8 @@ static void ggml_compute_forward_sqrt(
{
ggml_compute_forward_sqrt_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3808,6 +4055,8 @@ static void ggml_compute_forward_sum(
{
ggml_compute_forward_sum_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3883,6 +4132,8 @@ static void ggml_compute_forward_mean(
{
ggml_compute_forward_mean_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3945,6 +4196,8 @@ static void ggml_compute_forward_repeat(
{
ggml_compute_forward_repeat_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -3991,6 +4244,8 @@ static void ggml_compute_forward_abs(
{
ggml_compute_forward_abs_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4037,6 +4292,8 @@ static void ggml_compute_forward_sgn(
{
ggml_compute_forward_sgn_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4083,6 +4340,8 @@ static void ggml_compute_forward_neg(
{
ggml_compute_forward_neg_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4129,6 +4388,8 @@ static void ggml_compute_forward_step(
{
ggml_compute_forward_step_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4175,6 +4436,8 @@ static void ggml_compute_forward_relu(
{
ggml_compute_forward_relu_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4238,6 +4501,8 @@ static void ggml_compute_forward_gelu(
{
ggml_compute_forward_gelu_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4320,6 +4585,8 @@ static void ggml_compute_forward_norm(
{
ggml_compute_forward_norm_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4916,6 +5183,10 @@ static void ggml_compute_forward_mul_mat(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(false); // TODO: implement
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
@ -4924,6 +5195,7 @@ static void ggml_compute_forward_mul_mat(
{
ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -4981,6 +5253,8 @@ static void ggml_compute_forward_scale(
{
ggml_compute_forward_scale_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5073,6 +5347,33 @@ static void ggml_compute_forward_get_rows_f16(
}
}
static void ggml_compute_forward_get_rows_q4_0(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
const int nc = src0->ne[0];
const int nr = ggml_nelements(src1);
assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr);
assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i];
dequantize_row_q4_0(
(const void *) ((char *) src0->data + r*src0->nb[1]),
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
}
}
static void ggml_compute_forward_get_rows_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
@ -5106,6 +5407,10 @@ static void ggml_compute_forward_get_rows(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_Q4_0:
{
ggml_compute_forward_get_rows_q4_0(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
@ -5114,6 +5419,7 @@ static void ggml_compute_forward_get_rows(
{
ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5172,6 +5478,8 @@ static void ggml_compute_forward_diag_mask_inf(
{
ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5263,6 +5571,8 @@ static void ggml_compute_forward_soft_max(
{
ggml_compute_forward_soft_max_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5343,6 +5653,8 @@ static void ggml_compute_forward_rope(
{
ggml_compute_forward_rope_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5610,6 +5922,8 @@ static void ggml_compute_forward_conv_1d_1s(
{
ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5876,6 +6190,8 @@ static void ggml_compute_forward_conv_1d_2s(
{
ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6359,6 +6675,8 @@ static void ggml_compute_forward_flash_attn(
{
ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6568,6 +6886,8 @@ static void ggml_compute_forward_flash_ff(
{
GGML_ASSERT(false); // TODO
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7338,19 +7658,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
//printf("cur = %zu\n", cur);
} else {
cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
cur = GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1);
}
#else
cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
#endif
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = 0;
} else if (node->src0->type == GGML_TYPE_Q4_0 &&
node->src1->type == GGML_TYPE_F32) {
cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
} else {
GGML_ASSERT(false);
}

Loading…
Cancel
Save