diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9db99cb0f..0e4eff331 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -345,6 +345,244 @@ static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); } + +static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; + + const int dblk_size = 8 * 2; + const int mblk_size = 8 * 2; + const int qblk_size = qk / 2; + const int qrow_size = k / 2; + const int drow_size = nb * dblk_size; + + uint8_t * y_q = y + 0; + uint8_t * y_d = y + qrow_size; + uint8_t * y_m = y_d + drow_size; + + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_1x4x2]; + unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); + + for (int j = 0; j < qk / 2; j++) { + uint8_t x0; + uint8_t x1; + if (j < 64) { + x0 = qs[j]; + x1 = qs[j + 64]; + } else { + x0 = qs[j + 64]; + x1 = qs[j + 128]; + } + y_q[i * qblk_size + j] = x0 | (x1 << 4); + } + + uint16_t * dst_d = (uint16_t *) (y_d + i * dblk_size); + dst_d[0] = x[i * 8 + 0].d; dst_d[1] = x[i * 8 + 1].d; + dst_d[2] = x[i * 8 + 2].d; dst_d[3] = x[i * 8 + 3].d; + dst_d[4] = x[i * 8 + 4].d; dst_d[5] = x[i * 8 + 5].d; + dst_d[6] = x[i * 8 + 6].d; dst_d[7] = x[i * 8 + 7].d; + + uint16_t * dst_m = (uint16_t *) (y_m + i * mblk_size); + dst_m[0] = x[i * 8 + 0].m; dst_m[1] = x[i * 8 + 1].m; + dst_m[2] = x[i * 8 + 2].m; dst_m[3] = x[i * 8 + 3].m; + dst_m[4] = x[i * 8 + 4].m; dst_m[5] = x[i * 8 + 5].m; + dst_m[6] = x[i * 8 + 6].m; dst_m[7] = x[i * 8 + 7].m; + } +} + +static void unrepack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; + + const int dblk_size = 8 * 2; + const int mblk_size = 8 * 2; + const int qblk_size = qk / 2; + const int qrow_size = k / 2; + const int drow_size = nb * dblk_size; + + const uint8_t * y_q = y + 0; + const uint8_t * y_d = y + qrow_size; + const uint8_t * y_m = y_d + drow_size; + + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_1x4x2]; + + for (int j = 0; j < qk / 2; j++) { + const int x0 = (y_q[i * qblk_size + j] & 0x0F); + const int x1 = (y_q[i * qblk_size + j] >> 4); + if (j < 64) { + qs[j] = x0; + qs[j + 64] = x1; + } else { + qs[j + 64] = x0; + qs[j + 128] = x1; + } + } + + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + + const uint16_t * src_d = (const uint16_t *) (y_d + i * dblk_size); + x[i * 8 + 0].d = src_d[0]; x[i * 8 + 1].d = src_d[1]; + x[i * 8 + 2].d = src_d[2]; x[i * 8 + 3].d = src_d[3]; + x[i * 8 + 4].d = src_d[4]; x[i * 8 + 5].d = src_d[5]; + x[i * 8 + 6].d = src_d[6]; x[i * 8 + 7].d = src_d[7]; + + const uint16_t * src_m = (const uint16_t *) (y_m + i * mblk_size); + x[i * 8 + 0].m = src_m[0]; x[i * 8 + 1].m = src_m[1]; + x[i * 8 + 2].m = src_m[2]; x[i * 8 + 3].m = src_m[3]; + x[i * 8 + 4].m = src_m[4]; x[i * 8 + 5].m = src_m[5]; + x[i * 8 + 6].m = src_m[6]; x[i * 8 + 7].m = src_m[7]; + } +} + +static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_1x4x2; + const int nb = (k + qk - 1) / qk; + + uint8_t qs[QK_Q4_1x4x2]; + memset(qs, 8, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + + x[i * 8 + 0].d = 0; x[i * 8 + 0].m = 0; + x[i * 8 + 1].d = 0; x[i * 8 + 1].m = 0; + x[i * 8 + 2].d = 0; x[i * 8 + 2].m = 0; + x[i * 8 + 3].d = 0; x[i * 8 + 3].m = 0; + x[i * 8 + 4].d = 0; x[i * 8 + 4].m = 0; + x[i * 8 + 5].d = 0; x[i * 8 + 5].m = 0; + x[i * 8 + 6].d = 0; x[i * 8 + 6].m = 0; + x[i * 8 + 7].d = 0; x[i * 8 + 7].m = 0; + } +} + +static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); + size_t row_size_rp = row_size * 2; + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + memcpy(buf_pd, src, row_size); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const uint8_t * src = (const uint8_t *) data + (n_full_rows * row_size); + uint8_t * dst = (uint8_t *) t->data + (n_full_rows * row_size); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_rp, row_size_rp); + ggml_aligned_free(buf_pd, row_size_pd); +} + +static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); + size_t row_size_rp = row_size * 2; + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + memcpy(buf_rp, src, row_size); + unrepack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, row_size); + } + + if (n_rem_bytes > 0) { + const uint8_t * src = (const uint8_t *) t->data + (n_full_rows * row_size); + uint8_t * dst = (uint8_t *) data + (n_full_rows * row_size); + memcpy(buf_rp, src, n_rem_bytes); + unrepack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, n_rem_bytes); + } + + ggml_aligned_free(buf_rp, row_size_rp); + ggml_aligned_free(buf_pd, row_size_pd); +} + + + + + static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) { static const int qk = QK4_0; @@ -369,7 +607,6 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -437,7 +674,6 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -1056,7 +1292,6 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1125,7 +1360,6 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1364,6 +1598,11 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); repack_q4_0_q4x4x2(tensor, data, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1_q4x4x2(tensor, data, size); + break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); @@ -1406,6 +1645,11 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_1(data, tensor, size); + break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); @@ -1500,6 +1744,20 @@ static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer } static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) { + if (t->type == GGML_TYPE_Q4_0 || t->type == GGML_TYPE_Q8_0 || t->type == GGML_TYPE_IQ4_NL || t->type == GGML_TYPE_MXFP4 || t->type == GGML_TYPE_Q4_1) { + int64_t nrows = ggml_nrows(t); + size_t row_size_pd = 0; + if (t->type == GGML_TYPE_Q4_0 || t->type == GGML_TYPE_IQ4_NL) { + row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + } else if (t->type == GGML_TYPE_Q4_1) { + row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); + } else if (t->type == GGML_TYPE_Q8_0) { + row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); + } else if (t->type == GGML_TYPE_MXFP4) { + row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); + } + return row_size_pd * nrows; + } return ggml_nbytes(t); } @@ -1651,7 +1909,7 @@ struct ggml_hexagon_opbatch { d_map.insert({t->data, ti}); uint64_t t_offset = (uint8_t *) t->data - sbuf->base; - size_t t_size = ggml_nbytes(t); + size_t t_size = ggml_backend_hexagon_buffer_type_get_alloc_size(NULL, t); htp_tensor &h = h_tens[ti]; h.bi = add_buffer(sbuf); @@ -2327,6 +2585,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -2377,6 +2636,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -3598,6 +3858,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 3ef0bcdb2..d0643bada 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -33,6 +33,9 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; +static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, +}; // MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value // kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 @@ -62,6 +65,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q4_1: + return (size_t) nb * (QK_Q4_1x4x2 / 2 + HMX_X4X2_DBLK_SIZE + HMX_X4X2_DBLK_SIZE); // 160 * nb case HTP_TYPE_Q8_0: return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: @@ -331,11 +336,12 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int start_tile, int end_tile) { const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL); const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : + (weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. @@ -351,7 +357,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( if (kt >= n_k_tiles) { kt = 0; ct++; } // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- - if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + if (is_q4 && weight_type != HTP_TYPE_Q4_1 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 bool upper = (sub_blk_base >= 4); @@ -441,7 +447,42 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // --- Single-tile fallback --- __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - if (is_q4) { + if (weight_type == HTP_TYPE_Q4_1) { + unsigned blk_idx = (kt * 32) / QK_Q4_1x4x2; + unsigned sub_blk = ((kt * 32) % QK_Q4_1x4x2) / 32; + bool upper = (sub_blk >= 4); + unsigned byte_off = blk_idx * (QK_Q4_1x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + unsigned min_off = qrow_size + (k_block / QK_Q4_1x4x2) * HMX_X4X2_DBLK_SIZE + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector m0 = Q6_Vh_vsplat_R(*(const uint32_t *)(r0 + min_off)); + v0 = Q6_Vhf_vadd_VhfVhf(v0, m0); + + HVX_Vector v1; + if (row1 < n_cols) { + v1 = dequantize_x4x2_q4_0_group_hvx( + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + HVX_Vector m1 = Q6_Vh_vsplat_R(*(const uint32_t *)(r1 + min_off)); + v1 = Q6_Vhf_vadd_VhfVhf(v1, m1); + } else { + v1 = Q6_V_vzero(); + } + + HVX_Vector v_interleaved = Q6_Vh_vshuff_Vh(Q6_W_vcombine_VV(v1, v0)); + Q6_vscatter_QRMVwV(q_mask64, (size_t) vtcm_dst, 2 * HMX_FP16_TILE_SIZE - 1, v_off, v_interleaved); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + kt++; t++; + } else if (is_q4) { unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; bool upper = (sub_blk >= 4); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 9d905a301..3d4fdb029 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -20,7 +20,9 @@ enum htp_data_type { HTP_TYPE_F32 = 0, HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, HTP_TYPE_Q8_0 = 8, + HTP_TYPE_Q8_1 = 9, HTP_TYPE_IQ4_NL = 20, HTP_TYPE_I32 = 26, HTP_TYPE_I64 = 27, @@ -28,7 +30,9 @@ enum htp_data_type { // types used internally for repack, dyn.quant, etc HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q4_1x4x2, HTP_TYPE_Q8_0x4x2, + HTP_TYPE_Q8_1x4x2, HTP_TYPE_MXFP4x4x2, HTP_TYPE_INVALID @@ -36,7 +40,9 @@ enum htp_data_type { // Constats for internal types #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q4_1x4x2 256 // 4x Q4_1 blocks packed with next 4x Q4_1 blocks #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_Q8_1x4x2 256 // 4x Q8_1 blocks concat with next 4x Q8_1 blocks #define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 46fc5602d..b883763d5 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -40,6 +40,8 @@ struct htp_matmul_context { const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); + void (*quantize_row_f32)(float * restrict x, uint8_t * restrict y, uint32_t k); + // Precomputed values uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; @@ -408,11 +410,11 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -486,11 +488,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -581,11 +583,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -785,11 +787,11 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t x_qblk_size = qk; // int8 const uint32_t x_qrow_size = n; // int8 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -880,11 +882,11 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint32_t qk = QK_Q8_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t x_qblk_size = qk; // int8 const uint32_t x_qrow_size = n; // int8 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -1013,11 +1015,11 @@ static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -1086,11 +1088,11 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -1405,7 +1407,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint32_t x_qblk_size = qk / 2; // fp4 const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -1536,7 +1538,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float const uint32_t x_qblk_size = qk / 2; // fp4 const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 2; // 8x __fp16 const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -2456,6 +2458,7 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes + // Combine and convert to fp16 HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); @@ -2464,10 +2467,8 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + HVX_Vector vd01_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008))); // 1.0 / 127.0 + HVX_Vector vd23_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008))); // 1.0 / 127.0 hvx_vec_store_u(y_d + 0, 2, vd01_hf); HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64); @@ -2480,8 +2481,9 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric // Divide input by the scale HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + vx01_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + // Convert to int8 HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); @@ -2574,6 +2576,503 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric } // Overrides input x + + +static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d, uint8_t * restrict y_s) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + // Read sums BEFORE we potentially overwrite x via t_d or t_s + HVX_Vector sum_sf_0 = hvx_vec_reduce_sum_f32(vx[0]); + HVX_Vector sum_sf_1 = hvx_vec_reduce_sum_f32(vx[1]); + HVX_Vector sum_sf_2 = hvx_vec_reduce_sum_f32(vx[2]); + HVX_Vector sum_sf_3 = hvx_vec_reduce_sum_f32(vx[3]); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes + + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008))); // 1.0 / 127.0 + HVX_Vector vd23_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008))); // 1.0 / 127.0 + + hvx_vec_store_u(y_d + 0, 2, vd01_hf); + HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64); + hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf); + + hvx_vec_store_u(y_d + 4, 2, vd23_hf); + rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64); + hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + + // Convert back to integer types + HVX_VectorPair vi01_wp = Q6_Ww_vcvt_VhfR(vx01_hf, 0); + HVX_Vector vi01_b = Q6_Vb_vshuffe_VbVb(zero, Q6_Vb_vdeal_Vb(Q6_Vb_vshuffe_VbVb(Q6_V_hi_W(vi01_wp), Q6_V_lo_W(vi01_wp)))); + HVX_VectorPair vi23_wp = Q6_Ww_vcvt_VhfR(vx23_hf, 0); + HVX_Vector vi23_b = Q6_Vb_vshuffe_VbVb(zero, Q6_Vb_vdeal_Vb(Q6_Vb_vshuffe_VbVb(Q6_V_hi_W(vi23_wp), Q6_V_lo_W(vi23_wp)))); + + *(HVX_UVector *) (y_q + 0) = vi01_b; + *(HVX_UVector *) (y_q + 64) = vi23_b; + + // Now we need to store them as fp16 + HVX_Vector sum0_qf = Q6_Vqf32_vsub_VsfVsf(sum_sf_0, zero); + HVX_Vector sum1_qf = Q6_Vqf32_vsub_VsfVsf(sum_sf_1, zero); + HVX_Vector sum2_qf = Q6_Vqf32_vsub_VsfVsf(sum_sf_2, zero); + HVX_Vector sum3_qf = Q6_Vqf32_vsub_VsfVsf(sum_sf_3, zero); + + HVX_Vector sum01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(sum1_qf, sum0_qf))); + HVX_Vector sum23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(sum3_qf, sum2_qf))); + + hvx_vec_store_u(y_s + 0, 2, sum01_hf); + HVX_Vector rotated_sum_hf = Q6_V_vror_VR(sum01_hf, 64); + hvx_vec_store_u(y_s + 2, 2, rotated_sum_hf); + + hvx_vec_store_u(y_s + 4, 2, sum23_hf); + rotated_sum_hf = Q6_V_vror_VR(sum23_hf, 64); + hvx_vec_store_u(y_s + 6, 2, rotated_sum_hf); +} + + +static inline void quantize_block_f32_q8_1x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d, uint8_t * restrict y_s) { + quantize_block_f32_q8_1x1(x + 0 * 128, y_q + 0 * 128, y_d + 0 * 8, y_s + 0 * 8); + quantize_block_f32_q8_1x1(x + 1 * 128, y_q + 1 * 128, y_d + 1 * 8, y_s + 1 * 8); +} + +static inline void quantize_block_f32_q8_1x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d, uint8_t * restrict y_s) { + quantize_block_f32_q8_1x2(x + 0 * 256, y_q + 0 * 256, y_d + 0 * 16, y_s + 0 * 16); + quantize_block_f32_q8_1x2(x + 1 * 256, y_q + 1 * 256, y_d + 1 * 16, y_s + 1 * 16); +} + + +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_1x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 2; // 8x __fp16 + const uint32_t sblk_size = 8 * 2; // 8x __fp16 + const uint32_t qblk_size = QK_Q8_1x4x2; // int8 + const uint32_t drow_size = nb * dblk_size; + + uint8_t * restrict y_q = y + 0; + uint8_t * restrict y_d = y + qrow_size; + uint8_t * restrict y_s = y_d + drow_size; + + // Use stack buffers for temp scales and sums to avoid aliasing `x` + // Hexagon has a large stack; up to 8192 bytes here is safe (max K ~32k -> 2k bytes) + uint8_t t_d_buf[8192]; + uint8_t t_s_buf[8192]; + uint8_t * restrict t_d = t_d_buf; + uint8_t * restrict t_s = t_s_buf; + + for (uint32_t i = 0; i < nb; i++) { +#if FP32_QUANTIZE_GROUP_SIZE == 32 + quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2, t_s + (i*2 + 0) * sblk_size/2); + quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2, t_s + (i*2 + 1) * sblk_size/2); +#elif FP32_QUANTIZE_GROUP_SIZE == 64 + quantize_block_f32_q8_1x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2, t_s + (i*2 + 0) * sblk_size/2); + quantize_block_f32_q8_1x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2, t_s + (i*2 + 1) * sblk_size/2); +#elif FP32_QUANTIZE_GROUP_SIZE == 128 + quantize_block_f32_q8_1x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2, t_s + (i*2 + 0) * sblk_size/2); + quantize_block_f32_q8_1x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2, t_s + (i*2 + 1) * sblk_size/2); +#else +#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" +#endif + } + + // copy the precalculated scales and sums from temp buffer into final interleaved dst locations + for (uint32_t i = 0; i < nb; i++) { + for (int j = 0; j < dblk_size; j += 2) { + y_d[i * dblk_size + j + 0] = t_d[i * dblk_size + j + 0]; + y_d[i * dblk_size + j + 1] = t_d[i * dblk_size + j + 1]; + } + for (int j = 0; j < sblk_size; j += 2) { + y_s[i * sblk_size + j + 0] = t_s[i * sblk_size + j + 0]; + y_s[i * sblk_size + j + 1] = t_s[i * sblk_size + j + 1]; + } + } +} + + + +static void vec_dot_q4_1x4x2_q8_1x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_mblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = (n / (QK_Q4_1x4x2)) * 8 * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_sblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = (n / (QK_Q8_1x4x2)) * 8 * 2; + // Note: The quantize_row dynamically sets y_dblk_size to 8 * 2 and y_sblk_size to 8 * 2. + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); + const uint8_t * restrict r0_x_m = ((const uint8_t *) vx0 + x_qrow_size + x_drow_size); + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + const uint8_t * restrict y_s = ((const uint8_t *) vy0 + y_qrow_size + y_drow_size); + + HVX_Vector r0_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + r0_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms)); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + r0_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms)); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); // The reduced vector has the same sum across all elements +} + +static void vec_dot_q4_1x4x2_q8_1x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_mblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = (n / (QK_Q4_1x4x2)) * 8 * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_sblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = (n / (QK_Q8_1x4x2)) * 8 * 2; + // Note: The quantize_row dynamically sets y_dblk_size to 8 * 2 and y_sblk_size to 8 * 2. + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); + const uint8_t * restrict r0_x_m = ((const uint8_t *) vx0 + x_qrow_size + x_drow_size); + + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1 + 0); + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1 + x_qrow_size); + const uint8_t * restrict r1_x_m = ((const uint8_t *) vx1 + x_qrow_size + x_drow_size); + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + const uint8_t * restrict y_s = ((const uint8_t *) vy0 + y_qrow_size + y_drow_size); + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + r0_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms)); + r1_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms)); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_s + i * y_sblk_size)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + r0_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms)); + r1_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms)); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q4_1x4x2_q8_1x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_1x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; + const uint32_t x_mblk_size = 8 * 4 * 2; + const uint32_t x_qblk_size = qk / 2; + const uint32_t x_qrow_size = n / 2; + const uint32_t x_drow_size = (n / (QK_Q4_1x4x2)) * 8 * 2; + + const uint32_t y_dblk_size = 8 * 4 * 2; + const uint32_t y_sblk_size = 8 * 4 * 2; + const uint32_t y_qblk_size = qk; + const uint32_t y_qrow_size = n; + const uint32_t y_drow_size = (n / (QK_Q8_1x4x2)) * 8 * 2; + // Note: The quantize_row dynamically sets y_dblk_size to 8 * 2 and y_sblk_size to 8 * 2. + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); + const uint8_t * restrict r0_x_m = ((const uint8_t *) vx0 + x_qrow_size + x_drow_size); + + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1 + 0); + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1 + x_qrow_size); + const uint8_t * restrict r1_x_m = ((const uint8_t *) vx1 + x_qrow_size + x_drow_size); + + const uint8_t * restrict r0_y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict r0_y_d = ((const uint8_t *) vy0 + y_qrow_size); + const uint8_t * restrict r0_y_s = ((const uint8_t *) vy0 + y_qrow_size + y_drow_size); + + const uint8_t * restrict r1_y_q = ((const uint8_t *) vy1 + 0); + const uint8_t * restrict r1_y_d = ((const uint8_t *) vy1 + y_qrow_size); + const uint8_t * restrict r1_y_s = ((const uint8_t *) vy1 + y_qrow_size + y_drow_size); + + HVX_Vector r00_sum = Q6_V_vzero(); + HVX_Vector r10_sum = Q6_V_vzero(); + HVX_Vector r01_sum = Q6_V_vzero(); + HVX_Vector r11_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 r0_y_qv = hvx_vec_load_q8x4x8_full(r0_y_q + i * y_qblk_size); + HVX_Vector_x8 r1_y_qv = hvx_vec_load_q8x4x8_full(r1_y_q + i * y_qblk_size); + + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r00_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, r0_y_qv)); + HVX_Vector r10_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, r0_y_qv)); + HVX_Vector r01_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, r1_y_qv)); + HVX_Vector r11_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, r1_y_qv)); + + HVX_Vector r0_y_dv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_y_d + i * y_dblk_size)); + HVX_Vector r1_y_dv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_y_d + i * y_dblk_size)); + + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r00_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, r0_y_dv))); + HVX_Vector r10_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, r0_y_dv))); + HVX_Vector r01_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, r1_y_dv))); + HVX_Vector r11_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, r1_y_dv))); + + HVX_Vector r00_fa = Q6_Vqf32_vmpy_VsfVsf(r00_ia, r00_dd); + HVX_Vector r10_fa = Q6_Vqf32_vmpy_VsfVsf(r10_ia, r10_dd); + HVX_Vector r01_fa = Q6_Vqf32_vmpy_VsfVsf(r01_ia, r01_dd); + HVX_Vector r11_fa = Q6_Vqf32_vmpy_VsfVsf(r11_ia, r11_dd); + + HVX_Vector r0_y_sv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_y_s + i * y_sblk_size)); + HVX_Vector r1_y_sv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_y_s + i * y_sblk_size)); + + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r00_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, r0_y_sv))); + HVX_Vector r10_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, r0_y_sv))); + HVX_Vector r01_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, r1_y_sv))); + HVX_Vector r11_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, r1_y_sv))); + + r00_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r00_fa, r00_ms)); + r10_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r10_fa, r10_ms)); + r01_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r01_fa, r01_ms)); + r11_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r11_fa, r11_ms)); + + r00_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r00_fa, r00_sum)); + r10_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r10_fa, r10_sum)); + r01_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r01_fa, r01_sum)); + r11_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r11_fa, r11_sum)); + } + + if (nloe) { + HVX_Vector_x8 r0_y_qv = hvx_vec_load_q8x4x8_partial(r0_y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r1_y_qv = hvx_vec_load_q8x4x8_partial(r1_y_q + i * y_qblk_size, nloe); + + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r00_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, r0_y_qv, nloe)); + HVX_Vector r10_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, r0_y_qv, nloe)); + HVX_Vector r01_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, r1_y_qv, nloe)); + HVX_Vector r11_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, r1_y_qv, nloe)); + + HVX_Vector r0_y_dv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_y_d + i * y_dblk_size)); + HVX_Vector r1_y_dv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_y_d + i * y_dblk_size)); + + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r00_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, r0_y_dv))); + HVX_Vector r10_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, r0_y_dv))); + HVX_Vector r01_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, r1_y_dv))); + HVX_Vector r11_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, r1_y_dv))); + + HVX_Vector r00_fa = Q6_Vqf32_vmpy_VsfVsf(r00_ia, r00_dd); + HVX_Vector r10_fa = Q6_Vqf32_vmpy_VsfVsf(r10_ia, r10_dd); + HVX_Vector r01_fa = Q6_Vqf32_vmpy_VsfVsf(r01_ia, r01_dd); + HVX_Vector r11_fa = Q6_Vqf32_vmpy_VsfVsf(r11_ia, r11_dd); + + HVX_Vector r0_y_sv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_y_s + i * y_sblk_size)); + HVX_Vector r1_y_sv = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_y_s + i * y_sblk_size)); + + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_m + i * x_mblk_size)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_m + i * x_mblk_size)); + + HVX_Vector r00_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, r0_y_sv))); + HVX_Vector r10_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, r0_y_sv))); + HVX_Vector r01_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, r1_y_sv))); + HVX_Vector r11_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, r1_y_sv))); + + r00_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r00_fa, r00_ms)); + r10_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r10_fa, r10_ms)); + r01_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r01_fa, r01_ms)); + r11_fa = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r11_fa, r11_ms)); + + r00_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r00_fa, r00_sum)); + r10_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r10_fa, r10_sum)); + r01_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r01_fa, r01_sum)); + r11_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r11_fa, r11_sum)); + } + + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r00_sum, r10_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r01_sum, r11_sum); + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); +} + static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { assert(k % 32 == 0); const uint32_t qk = QK_Q8_0x4x2; @@ -2645,7 +3144,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) hvx_copy_f32_aa(tmp_data, src_data, ne0); // FARF(HIGH, "quantize-q8x4-row: %u\n", i); - quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0); + mmctx->quantize_row_f32((float *) tmp_data, dst_data, ne0); dst_data += dst_row_size; src_data += src_row_size; } @@ -2751,24 +3250,35 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + mmctx->quantize_row_f32 = quantize_row_f32_q8x4x2; + return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8_1x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8_1x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8_1x4x2_2x2; + mmctx->quantize_row_f32 = quantize_row_f32_q8_1x4x2; return 0; case HTP_TYPE_Q8_0: mmctx->type = "q8x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + mmctx->quantize_row_f32 = quantize_row_f32_q8x4x2; return 0; case HTP_TYPE_IQ4_NL: mmctx->type = "iq4nlx4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + mmctx->quantize_row_f32 = quantize_row_f32_q8x4x2; return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + mmctx->quantize_row_f32 = quantize_row_f32_q8x4x2; return 0; default: return -1;