@ -366,9 +366,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert ( k % QK = = 0 ) ;
assert ( k % QK = = 0 ) ;
const int nb = k / QK ;
const int nb = k / QK ;
const size_t bs = sizeof ( float ) + QK / 2 ;
float * restrict pd = ( float * ) ( y ) ;
uint8_t * restrict pd = ( uint8_t * ) ( y + 0 * bs ) ;
uint8_t * restrict pb = ( uint8_t * ) ( pd + nb ) ;
uint8_t * restrict pb = ( uint8_t * ) ( y + 0 * bs + sizeof ( float ) ) ;
uint8_t pp [ QK / 2 ] ;
uint8_t pp [ QK / 2 ] ;
@ -395,7 +396,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ( ( 1 < < 3 ) - 1 ) ;
const float d = amax / ( ( 1 < < 3 ) - 1 ) ;
const float id = d ? 1.0 / d : 0.0 ;
const float id = d ? 1.0 / d : 0.0 ;
pd [ i ] = d ;
* ( float * ) pd = d ;
pd + = bs ;
for ( int l = 0 ; l < 8 ; l + + ) {
for ( int l = 0 ; l < 8 ; l + + ) {
const float32x4_t v = vmulq_n_f32 ( srcv [ l ] , id ) ;
const float32x4_t v = vmulq_n_f32 ( srcv [ l ] , id ) ;
@ -406,7 +408,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp [ 2 * l + 1 ] = vgetq_lane_s32 ( vi , 2 ) | ( vgetq_lane_s32 ( vi , 3 ) < < 4 ) ;
pp [ 2 * l + 1 ] = vgetq_lane_s32 ( vi , 2 ) | ( vgetq_lane_s32 ( vi , 3 ) < < 4 ) ;
}
}
memcpy ( pb + i * 16 , pp , sizeof ( pp ) ) ;
memcpy ( pb , pp , sizeof ( pp ) ) ;
pb + = bs ;
}
}
# else
# else
# error "not implemented for QK"
# error "not implemented for QK"
@ -434,7 +437,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ( ( 1 < < 3 ) - 1 ) ;
const float d = amax / ( ( 1 < < 3 ) - 1 ) ;
const float id = d ? 1.0 / d : 0.0 ;
const float id = d ? 1.0 / d : 0.0 ;
pd [ i ] = d ;
* ( float * ) pd = d ;
pd + = bs ;
for ( int l = 0 ; l < 8 ; l + + ) {
for ( int l = 0 ; l < 8 ; l + + ) {
const v128_t v = wasm_f32x4_mul ( srcv [ l ] , wasm_f32x4_splat ( id ) ) ;
const v128_t v = wasm_f32x4_mul ( srcv [ l ] , wasm_f32x4_splat ( id ) ) ;
@ -445,7 +449,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp [ 2 * l + 1 ] = wasm_i32x4_extract_lane ( vi , 2 ) | ( wasm_i32x4_extract_lane ( vi , 3 ) < < 4 ) ;
pp [ 2 * l + 1 ] = wasm_i32x4_extract_lane ( vi , 2 ) | ( wasm_i32x4_extract_lane ( vi , 3 ) < < 4 ) ;
}
}
memcpy ( pb + i * 16 , pp , sizeof ( pp ) ) ;
memcpy ( pb , pp , sizeof ( pp ) ) ;
pb + = bs ;
}
}
# else
# else
# error "not implemented for QK"
# error "not implemented for QK"
@ -463,7 +468,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
const float d = amax / ( ( 1 < < 3 ) - 1 ) ;
const float d = amax / ( ( 1 < < 3 ) - 1 ) ;
const float id = d ? 1.0f / d : 0.0f ;
const float id = d ? 1.0f / d : 0.0f ;
pd [ i ] = d ;
* ( float * ) pd = d ;
pd + = bs ;
for ( int l = 0 ; l < QK ; l + = 2 ) {
for ( int l = 0 ; l < QK ; l + = 2 ) {
const float v0 = x [ i * QK + l + 0 ] * id ;
const float v0 = x [ i * QK + l + 0 ] * id ;
@ -478,7 +484,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pp [ l / 2 ] = vi0 | ( vi1 < < 4 ) ;
pp [ l / 2 ] = vi0 | ( vi1 < < 4 ) ;
}
}
memcpy ( pb + i * QK / 2 , pp , sizeof ( pp ) ) ;
memcpy ( pb , pp , sizeof ( pp ) ) ;
pb + = bs ;
}
}
# endif
# endif
}
}
@ -535,15 +542,16 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
assert ( k % QK = = 0 ) ;
assert ( k % QK = = 0 ) ;
const int nb = k / QK ;
const int nb = k / QK ;
const size_t bs = sizeof ( float ) + QK / 2 ;
const float * restrict pd = ( const float * ) ( x ) ;
const uint8_t * restrict pd = ( const uint8_t * ) ( x + 0 * bs ) ;
const uint8_t * restrict pb = ( const uint8_t * ) ( pd + nb ) ;
const uint8_t * restrict pb = ( const uint8_t * ) ( x + 0 * bs + sizeof ( float ) ) ;
// scalar
// scalar
for ( int i = 0 ; i < nb ; i + + ) {
for ( int i = 0 ; i < nb ; i + + ) {
const float d = pd [ i ] ;
const float d = * ( const float * ) ( pd + i * bs ) ;
const uint8_t * restrict pp = pb + i * QK/ 2 ;
const uint8_t * restrict pp = pb + i * bs ;
for ( int l = 0 ; l < QK ; l + = 2 ) {
for ( int l = 0 ; l < QK ; l + = 2 ) {
const uint8_t vi = pp [ l / 2 ] ;
const uint8_t vi = pp [ l / 2 ] ;
@ -554,6 +562,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const float v0 = ( vi0 - 8 ) * d ;
const float v0 = ( vi0 - 8 ) * d ;
const float v1 = ( vi1 - 8 ) * d ;
const float v1 = ( vi1 - 8 ) * d ;
//printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
y [ i * QK + l + 0 ] = v0 ;
y [ i * QK + l + 0 ] = v0 ;
y [ i * QK + l + 1 ] = v1 ;
y [ i * QK + l + 1 ] = v1 ;
@ -1179,11 +1189,13 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
assert ( n % QK = = 0 ) ;
assert ( n % QK = = 0 ) ;
assert ( nb % 2 = = 0 ) ;
assert ( nb % 2 = = 0 ) ;
const float * restrict pd0 = ( const float * ) x ;
const size_t bs = sizeof ( float ) + QK / 2 ;
const float * restrict pd1 = ( const float * ) y ;
const uint8_t * restrict pb0 = ( const uint8_t * ) ( pd0 + nb ) ;
const uint8_t * restrict pd0 = ( const uint8_t * ) ( x + 0 * bs ) ;
const uint8_t * restrict pb1 = ( const uint8_t * ) ( pd1 + nb ) ;
const uint8_t * restrict pd1 = ( const uint8_t * ) ( y + 0 * bs ) ;
const uint8_t * restrict pb0 = ( const uint8_t * ) ( x + 0 * bs + sizeof ( float ) ) ;
const uint8_t * restrict pb1 = ( const uint8_t * ) ( y + 0 * bs + sizeof ( float ) ) ;
float sumf = 0.0 ;
float sumf = 0.0 ;
@ -1193,23 +1205,23 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f ;
float sum1 = 0.0f ;
for ( int i = 0 ; i < nb ; i + = 2 ) {
for ( int i = 0 ; i < nb ; i + = 2 ) {
const float d0_0 = pd0 [ i + 0 ] ;
const float d0_0 = * ( const float * ) ( pd0 + i * bs ) ;
const float d1_0 = pd1 [ i + 0 ] ;
const float d1_0 = * ( const float * ) ( pd1 + i * bs ) ;
const float d0_1 = pd0 [ i + 1 ] ;
const float d0_1 = * ( const float * ) ( pd0 + ( i + 1 ) * bs ) ;
const float d1_1 = pd1 [ i + 1 ] ;
const float d1_1 = * ( const float * ) ( pd1 + ( i + 1 ) * bs ) ;
//printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
//printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
const uint8_t * restrict p0 = pb0 + i * 16 ;
const uint8_t * restrict p0 = pb0 + i * bs ;
const uint8_t * restrict p1 = pb1 + i * 16 ;
const uint8_t * restrict p1 = pb1 + i * bs ;
const uint8x16_t m4b = vdupq_n_u8 ( 0xf ) ;
const uint8x16_t m4b = vdupq_n_u8 ( 0xf ) ;
const int8x16_t s8b = vdupq_n_s8 ( 0x8 ) ;
const int8x16_t s8b = vdupq_n_s8 ( 0x8 ) ;
const uint8x16_t v0_0 = vld1q_u8 ( p0 ) ;
const uint8x16_t v0_0 = vld1q_u8 ( p0 ) ;
const uint8x16_t v1_0 = vld1q_u8 ( p1 ) ;
const uint8x16_t v1_0 = vld1q_u8 ( p1 ) ;
const uint8x16_t v0_1 = vld1q_u8 ( p0 + 16 ) ;
const uint8x16_t v0_1 = vld1q_u8 ( p0 + bs ) ;
const uint8x16_t v1_1 = vld1q_u8 ( p1 + 16 ) ;
const uint8x16_t v1_1 = vld1q_u8 ( p1 + bs ) ;
// 4-bit -> 8-bit
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8 ( vandq_u8 ( v0_0 , m4b ) ) ;
const int8x16_t v0_0l = vreinterpretq_s8_u8 ( vandq_u8 ( v0_0 , m4b ) ) ;
@ -1280,21 +1292,21 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sum1 = 0.0f ;
float sum1 = 0.0f ;
for ( int i = 0 ; i < nb ; i + = 2 ) {
for ( int i = 0 ; i < nb ; i + = 2 ) {
const float d0_0 = pd0 [ i + 0 ] ;
const float d0_0 = * ( const float * ) ( pd0 + i * bs ) ;
const float d 0_1 = pd0 [ i + 1 ] ;
const float d 1_0 = * ( const float * ) ( pd1 + i * bs ) ;
const float d 1_0 = pd1 [ i + 0 ] ;
const float d 0_1 = * ( const float * ) ( pd0 + ( i + 1 ) * bs ) ;
const float d1_1 = pd1 [ i + 1 ] ;
const float d1_1 = * ( const float * ) ( pd1 + ( i + 1 ) * bs ) ;
const uint8_t * restrict p0 = pb0 + i * 16 ;
const uint8_t * restrict p0 = pb0 + i * bs ;
const uint8_t * restrict p1 = pb1 + i * 16 ;
const uint8_t * restrict p1 = pb1 + i * bs ;
const v128_t m4b = wasm_u8x16_splat ( 0xf ) ;
const v128_t m4b = wasm_u8x16_splat ( 0xf ) ;
const v128_t s8b = wasm_i8x16_splat ( 0x8 ) ;
const v128_t s8b = wasm_i8x16_splat ( 0x8 ) ;
const v128_t v0_0 = wasm_v128_load ( p0 ) ;
const v128_t v0_0 = wasm_v128_load ( p0 ) ;
const v128_t v0_1 = wasm_v128_load ( p0 + 16 ) ;
const v128_t v0_1 = wasm_v128_load ( p0 + bs ) ;
const v128_t v1_0 = wasm_v128_load ( p1 ) ;
const v128_t v1_0 = wasm_v128_load ( p1 ) ;
const v128_t v1_1 = wasm_v128_load ( p1 + 16 ) ;
const v128_t v1_1 = wasm_v128_load ( p1 + bs ) ;
// 4-bit -> 8-bit
// 4-bit -> 8-bit
const v128_t v0_0l = wasm_v128_and ( v0_0 , m4b ) ;
const v128_t v0_0l = wasm_v128_and ( v0_0 , m4b ) ;
@ -1363,11 +1375,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
# else
# else
// scalar
// scalar
for ( int i = 0 ; i < nb ; i + + ) {
for ( int i = 0 ; i < nb ; i + + ) {
const float d0 = pd0 [ i ] ;
const float d0 = * ( const float * ) ( pd0 + i * bs ) ;
const float d1 = pd1 [ i ] ;
const float d1 = * ( const float * ) ( pd1 + i * bs ) ;
const uint8_t * restrict p0 = pb0 + i * QK/ 2 ;
const uint8_t * restrict p0 = pb0 + i * bs ;
const uint8_t * restrict p1 = pb1 + i * QK/ 2 ;
const uint8_t * restrict p1 = pb1 + i * bs ;
for ( int j = 0 ; j < QK / 2 ; j + + ) {
for ( int j = 0 ; j < QK / 2 ; j + + ) {
const uint8_t v0 = p0 [ j ] ;
const uint8_t v0 = p0 [ j ] ;
@ -1552,16 +1564,17 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
assert ( n % QK = = 0 ) ;
assert ( n % QK = = 0 ) ;
const int nb = n / QK ;
const int nb = n / QK ;
const size_t bs = sizeof ( float ) + QK / 2 ;
const float * restrict pd = ( const float * ) ( x ) ;
const uint8_t * restrict pd = ( const uint8_t * ) ( x + 0 * bs ) ;
const uint8_t * restrict pb = ( const uint8_t * ) ( pd + nb ) ;
const uint8_t * restrict pb = ( const uint8_t * ) ( x + 0 * bs + sizeof ( float ) ) ;
# if __ARM_NEON
# if __ARM_NEON
# if QK == 32
# if QK == 32
for ( int i = 0 ; i < nb ; + + i ) {
for ( int i = 0 ; i < nb ; + + i ) {
const float d0 = pd[ i ] * v ;
const float d0 = v* ( * ( const float * ) ( pd + i * bs ) ) ;
const uint8_t * restrict pp = pb + i * 16 ;
const uint8_t * restrict pp = pb + i * bs ;
const uint8x8_t m4b = vdup_n_u8 ( 0xf ) ;
const uint8x8_t m4b = vdup_n_u8 ( 0xf ) ;
const int8x8_t s8b = vdup_n_s8 ( 0x8 ) ;
const int8x8_t s8b = vdup_n_s8 ( 0x8 ) ;
@ -1615,9 +1628,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
# else
# else
// scalar
// scalar
for ( int i = 0 ; i < nb ; i + + ) {
for ( int i = 0 ; i < nb ; i + + ) {
const float d = pd [ i ] ;
const float d = * ( const float * ) ( pd + i * bs ) ;
const uint8_t * restrict pp = pb + i * QK/ 2 ;
const uint8_t * restrict pp = pb + i * bs ;
for ( int l = 0 ; l < QK ; l + = 2 ) {
for ( int l = 0 ; l < QK ; l + = 2 ) {
const uint8_t vi = pp [ l / 2 ] ;
const uint8_t vi = pp [ l / 2 ] ;