#include #include #include #include #include #include #include #include const int N = 1 << 14; const int M = 768; // // naive implementation // void mul_mat_vec_f32_0( const float * restrict src0, const float * restrict src1, float * dst, int nrows, int ncols) { for (int i = 0; i < nrows; i++) { float sum = 0.0f; for (int j = 0; j < ncols; j++) { sum += src0[i*ncols + j]*src1[j]; } dst[i] = sum; } } void mul_mat_vec_f16_0( const __fp16 * src0, const __fp16 * src1, float * dst, int nrows, int ncols) { const int n64 = ncols & ~63; for (int r = 0; r < nrows; r++) { float sumf = 0.0; float16x8_t sum0 = vdupq_n_f16(0.0f); float16x8_t sum1 = vdupq_n_f16(0.0f); float16x8_t sum2 = vdupq_n_f16(0.0f); float16x8_t sum3 = vdupq_n_f16(0.0f); float16x8_t sum4 = vdupq_n_f16(0.0f); float16x8_t sum5 = vdupq_n_f16(0.0f); float16x8_t sum6 = vdupq_n_f16(0.0f); float16x8_t sum7 = vdupq_n_f16(0.0f); float16x8_t x0, x1, x2, x3, x4, x5, x6, x7; float16x8_t y0, y1, y2, y3, y4, y5, y6, y7; const __fp16 * restrict p0 = src0 + r*ncols; for (int i = 0; i < n64; i += 64) { x0 = vld1q_f16(p0 + i + 0 ); x1 = vld1q_f16(p0 + i + 8 ); x2 = vld1q_f16(p0 + i + 16); x3 = vld1q_f16(p0 + i + 24); x4 = vld1q_f16(p0 + i + 32); x5 = vld1q_f16(p0 + i + 40); x6 = vld1q_f16(p0 + i + 48); x7 = vld1q_f16(p0 + i + 56); y0 = vld1q_f16(src1 + i + 0 ); y1 = vld1q_f16(src1 + i + 8 ); y2 = vld1q_f16(src1 + i + 16); y3 = vld1q_f16(src1 + i + 24); y4 = vld1q_f16(src1 + i + 32); y5 = vld1q_f16(src1 + i + 40); y6 = vld1q_f16(src1 + i + 48); y7 = vld1q_f16(src1 + i + 56); sum0 = vfmaq_f16(sum0, x0, y0); sum1 = vfmaq_f16(sum1, x1, y1); sum2 = vfmaq_f16(sum2, x2, y2); sum3 = vfmaq_f16(sum3, x3, y3); sum4 = vfmaq_f16(sum4, x4, y4); sum5 = vfmaq_f16(sum5, x5, y5); sum6 = vfmaq_f16(sum6, x6, y6); sum7 = vfmaq_f16(sum7, x7, y7); } // TODO: F16 - better way to reduce this ? float16x8_t sum = vaddq_f16(sum0, sum1); sum = vaddq_f16(sum, sum2); sum = vaddq_f16(sum, sum3); sum = vaddq_f16(sum, sum4); sum = vaddq_f16(sum, sum5); sum = vaddq_f16(sum, sum6); sum = vaddq_f16(sum, sum7); sumf += sum[0] + sum[1] + sum[2] + sum[3] + sum[4] + sum[5] + sum[6] + sum[7]; for (int j = n64; j < n64; j++) { sumf += src0[r*ncols + j]*src1[j]; } dst[r] = sumf; } } uint64_t get_time_us() { struct timeval tv; gettimeofday(&tv, NULL); return tv.tv_sec * 1000000 + tv.tv_usec; } int main(int argc, const char ** argv) { float * src0 = (float *)malloc(sizeof(float)*N*M); float * src1 = (float *)malloc(sizeof(float)*M); float * dst = (float *)malloc(sizeof(float)*N); //float * src0 = (float *)(aligned_alloc(64, sizeof(float)*N*M)); //float * src1 = (float *)(aligned_alloc(64, sizeof(float)*M)); //float * dst = (float *)(aligned_alloc(64, sizeof(float)*N)); for (int i = 0; i < N*M; i++) { src0[i] = rand() / (float)RAND_MAX; } for (int i = 0; i < M; i++) { src1[i] = rand() / (float)RAND_MAX; } // convert src0 and src1 to __fp16 __fp16 * src0_fp16 = (__fp16 *)(malloc(sizeof(__fp16)*N*M)); __fp16 * src1_fp16 = (__fp16 *)(malloc(sizeof(__fp16)*M)); { const uint64_t t_start = get_time_us(); for (int i = 0; i < N*M; i++) { src0_fp16[i] = src0[i]; //printf("%f %f\n", src0[i], src0_fp16[i]); //assert(!isnan(src0_fp16[i])); } for (int i = 0; i < M; i++) { src1_fp16[i] = src1[i]; } const uint64_t t_end = get_time_us(); printf("convert time: %f ms\n", (t_end - t_start) / 1000.0); } for (int i = 0; i < 16; ++i) { printf("%f %f\n", src0[i], src0_fp16[i]); } int method = 0; if (argc > 1) { method = atoi(argv[1]); } const int nIter = 1000; const clock_t start = clock(); const uint64_t start_us = get_time_us(); double iM = 1.0/M; double sum = 0.0f; for (int i = 0; i < nIter; i++) { if (method == 0) { mul_mat_vec_f32_0(src0, src1, dst, N, M); } if (method == 1) { mul_mat_vec_f16_0(src0_fp16, src1_fp16, dst, N, M); } } for (int i = 0; i < N; i++) { sum += dst[i]*iM; } { const clock_t end = clock(); const uint64_t end_us = get_time_us(); printf("%s: elapsed ticks: %ld\n", __func__, end - start); printf("%s: elapsed us: %llu\n", __func__, end_us - start_us); } printf("%f\n", sum); free(src0); free(src1); free(dst); free(src0_fp16); free(src1_fp16); return 0; }