Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions crates/core_arch/src/x86/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1813,14 +1813,20 @@ pub const fn _mm256_inserti128_si256<const IMM1: i32>(a: __m256i, b: __m128i) ->
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm256_madd_epi16(a: __m256i, b: __m256i) -> __m256i {
unsafe {
let r: i32x16 = simd_mul(simd_cast(a.as_i16x16()), simd_cast(b.as_i16x16()));
let even: i32x8 = simd_shuffle!(r, r, [0, 2, 4, 6, 8, 10, 12, 14]);
let odd: i32x8 = simd_shuffle!(r, r, [1, 3, 5, 7, 9, 11, 13, 15]);
simd_add(even, odd).as_m256i()
}
pub fn _mm256_madd_epi16(a: __m256i, b: __m256i) -> __m256i {
    // Deliberately lowered through the LLVM x86 intrinsic rather than the
    // generic SIMD operations.
    //
    // Adler-32 implementations call this intrinsic with a vector of ones to
    // perform a widening addition:
    //
    // ```rust
    // #[target_feature(enable = "avx2")]
    // unsafe fn widening_add(mad: __m256i) -> __m256i {
    //     _mm256_madd_epi16(mad, _mm256_set1_epi16(1))
    // }
    // ```
    //
    // When this function is written with generic vector intrinsics, the
    // optimizer eliminates that pattern and `vpmaddwd` is no longer emitted,
    // so the x86 intrinsic is called directly instead.
    unsafe {
        let lhs = a.as_i16x16();
        let rhs = b.as_i16x16();
        transmute(pmaddwd(lhs, rhs))
    }
}

/// Vertically multiplies each unsigned 8-bit integer from `a` with the
Expand Down Expand Up @@ -3789,6 +3795,8 @@ unsafe extern "C" {
fn phaddsw(a: i16x16, b: i16x16) -> i16x16;
#[link_name = "llvm.x86.avx2.phsub.sw"]
fn phsubsw(a: i16x16, b: i16x16) -> i16x16;
#[link_name = "llvm.x86.avx2.pmadd.wd"]
fn pmaddwd(a: i16x16, b: i16x16) -> i32x8;
#[link_name = "llvm.x86.avx2.pmadd.ub.sw"]
fn pmaddubsw(a: u8x32, b: i8x32) -> i16x16;
#[link_name = "llvm.x86.avx2.mpsadbw"]
Expand Down Expand Up @@ -4637,7 +4645,7 @@ mod tests {
}

#[simd_test(enable = "avx2")]
const fn test_mm256_madd_epi16() {
fn test_mm256_madd_epi16() {
let a = _mm256_set1_epi16(2);
let b = _mm256_set1_epi16(4);
let r = _mm256_madd_epi16(a, b);
Expand Down
64 changes: 29 additions & 35 deletions crates/core_arch/src/x86/avx512bw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6321,22 +6321,20 @@ pub const unsafe fn _mm_mask_storeu_epi8(mem_addr: *mut i8, mask: __mmask16, a:
#[target_feature(enable = "avx512bw")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm512_madd_epi16(a: __m512i, b: __m512i) -> __m512i {
unsafe {
let r: i32x32 = simd_mul(simd_cast(a.as_i16x32()), simd_cast(b.as_i16x32()));
let even: i32x16 = simd_shuffle!(
r,
r,
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
);
let odd: i32x16 = simd_shuffle!(
r,
r,
[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31]
);
simd_add(even, odd).as_m512i()
}
pub fn _mm512_madd_epi16(a: __m512i, b: __m512i) -> __m512i {
    // Deliberately lowered through the LLVM x86 intrinsic rather than the
    // generic SIMD operations.
    //
    // Adler-32 implementations call this intrinsic with a vector of ones to
    // perform a widening addition:
    //
    // ```rust
    // #[target_feature(enable = "avx512bw")]
    // unsafe fn widening_add(mad: __m512i) -> __m512i {
    //     _mm512_madd_epi16(mad, _mm512_set1_epi16(1))
    // }
    // ```
    //
    // When this function is written with generic vector intrinsics, the
    // optimizer eliminates that pattern and `vpmaddwd` is no longer emitted,
    // so the x86 intrinsic is called directly instead.
    unsafe {
        let lhs = a.as_i16x32();
        let rhs = b.as_i16x32();
        transmute(vpmaddwd(lhs, rhs))
    }
}

/// Multiply packed signed 16-bit integers in a and b, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers, and pack the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
Expand All @@ -6346,8 +6344,7 @@ pub const fn _mm512_madd_epi16(a: __m512i, b: __m512i) -> __m512i {
#[target_feature(enable = "avx512bw")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm512_mask_madd_epi16(src: __m512i, k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
pub fn _mm512_mask_madd_epi16(src: __m512i, k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
unsafe {
let madd = _mm512_madd_epi16(a, b).as_i32x16();
transmute(simd_select_bitmask(k, madd, src.as_i32x16()))
Expand All @@ -6361,8 +6358,7 @@ pub const fn _mm512_mask_madd_epi16(src: __m512i, k: __mmask16, a: __m512i, b: _
#[target_feature(enable = "avx512bw")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm512_maskz_madd_epi16(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
pub fn _mm512_maskz_madd_epi16(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
unsafe {
let madd = _mm512_madd_epi16(a, b).as_i32x16();
transmute(simd_select_bitmask(k, madd, i32x16::ZERO))
Expand All @@ -6376,8 +6372,7 @@ pub const fn _mm512_maskz_madd_epi16(k: __mmask16, a: __m512i, b: __m512i) -> __
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm256_mask_madd_epi16(src: __m256i, k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
pub fn _mm256_mask_madd_epi16(src: __m256i, k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
unsafe {
let madd = _mm256_madd_epi16(a, b).as_i32x8();
transmute(simd_select_bitmask(k, madd, src.as_i32x8()))
Expand All @@ -6391,8 +6386,7 @@ pub const fn _mm256_mask_madd_epi16(src: __m256i, k: __mmask8, a: __m256i, b: __
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm256_maskz_madd_epi16(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
pub fn _mm256_maskz_madd_epi16(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
unsafe {
let madd = _mm256_madd_epi16(a, b).as_i32x8();
transmute(simd_select_bitmask(k, madd, i32x8::ZERO))
Expand All @@ -6406,8 +6400,7 @@ pub const fn _mm256_maskz_madd_epi16(k: __mmask8, a: __m256i, b: __m256i) -> __m
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm_mask_madd_epi16(src: __m128i, k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
pub fn _mm_mask_madd_epi16(src: __m128i, k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
unsafe {
let madd = _mm_madd_epi16(a, b).as_i32x4();
transmute(simd_select_bitmask(k, madd, src.as_i32x4()))
Expand All @@ -6421,8 +6414,7 @@ pub const fn _mm_mask_madd_epi16(src: __m128i, k: __mmask8, a: __m128i, b: __m12
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm_maskz_madd_epi16(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
pub fn _mm_maskz_madd_epi16(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
unsafe {
let madd = _mm_madd_epi16(a, b).as_i32x4();
transmute(simd_select_bitmask(k, madd, i32x4::ZERO))
Expand Down Expand Up @@ -12582,6 +12574,8 @@ unsafe extern "C" {
#[link_name = "llvm.x86.avx512.pmul.hr.sw.512"]
fn vpmulhrsw(a: i16x32, b: i16x32) -> i16x32;

#[link_name = "llvm.x86.avx512.pmaddw.d.512"]
fn vpmaddwd(a: i16x32, b: i16x32) -> i32x16;
#[link_name = "llvm.x86.avx512.pmaddubs.w.512"]
fn vpmaddubsw(a: u8x64, b: i8x64) -> i16x32;

Expand Down Expand Up @@ -17486,7 +17480,7 @@ mod tests {
}

#[simd_test(enable = "avx512bw")]
const fn test_mm512_madd_epi16() {
fn test_mm512_madd_epi16() {
let a = _mm512_set1_epi16(1);
let b = _mm512_set1_epi16(1);
let r = _mm512_madd_epi16(a, b);
Expand All @@ -17495,7 +17489,7 @@ mod tests {
}

#[simd_test(enable = "avx512bw")]
const fn test_mm512_mask_madd_epi16() {
fn test_mm512_mask_madd_epi16() {
let a = _mm512_set1_epi16(1);
let b = _mm512_set1_epi16(1);
let r = _mm512_mask_madd_epi16(a, 0, a, b);
Expand Down Expand Up @@ -17523,7 +17517,7 @@ mod tests {
}

#[simd_test(enable = "avx512bw")]
const fn test_mm512_maskz_madd_epi16() {
fn test_mm512_maskz_madd_epi16() {
let a = _mm512_set1_epi16(1);
let b = _mm512_set1_epi16(1);
let r = _mm512_maskz_madd_epi16(0, a, b);
Expand All @@ -17534,7 +17528,7 @@ mod tests {
}

#[simd_test(enable = "avx512bw,avx512vl")]
const fn test_mm256_mask_madd_epi16() {
fn test_mm256_mask_madd_epi16() {
let a = _mm256_set1_epi16(1);
let b = _mm256_set1_epi16(1);
let r = _mm256_mask_madd_epi16(a, 0, a, b);
Expand All @@ -17554,7 +17548,7 @@ mod tests {
}

#[simd_test(enable = "avx512bw,avx512vl")]
const fn test_mm256_maskz_madd_epi16() {
fn test_mm256_maskz_madd_epi16() {
let a = _mm256_set1_epi16(1);
let b = _mm256_set1_epi16(1);
let r = _mm256_maskz_madd_epi16(0, a, b);
Expand All @@ -17565,7 +17559,7 @@ mod tests {
}

#[simd_test(enable = "avx512bw,avx512vl")]
const fn test_mm_mask_madd_epi16() {
fn test_mm_mask_madd_epi16() {
let a = _mm_set1_epi16(1);
let b = _mm_set1_epi16(1);
let r = _mm_mask_madd_epi16(a, 0, a, b);
Expand All @@ -17576,7 +17570,7 @@ mod tests {
}

#[simd_test(enable = "avx512bw,avx512vl")]
const fn test_mm_maskz_madd_epi16() {
fn test_mm_maskz_madd_epi16() {
let a = _mm_set1_epi16(1);
let b = _mm_set1_epi16(1);
let r = _mm_maskz_madd_epi16(0, a, b);
Expand Down
26 changes: 17 additions & 9 deletions crates/core_arch/src/x86/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,20 @@ pub const fn _mm_avg_epu16(a: __m128i, b: __m128i) -> __m128i {
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(pmaddwd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm_madd_epi16(a: __m128i, b: __m128i) -> __m128i {
unsafe {
let r: i32x8 = simd_mul(simd_cast(a.as_i16x8()), simd_cast(b.as_i16x8()));
let even: i32x4 = simd_shuffle!(r, r, [0, 2, 4, 6]);
let odd: i32x4 = simd_shuffle!(r, r, [1, 3, 5, 7]);
simd_add(even, odd).as_m128i()
}
pub fn _mm_madd_epi16(a: __m128i, b: __m128i) -> __m128i {
    // Deliberately lowered through the LLVM x86 intrinsic rather than the
    // generic SIMD operations.
    //
    // Adler-32 implementations call this intrinsic with a vector of ones to
    // perform a widening addition:
    //
    // ```rust
    // #[target_feature(enable = "sse2")]
    // unsafe fn widening_add(mad: __m128i) -> __m128i {
    //     _mm_madd_epi16(mad, _mm_set1_epi16(1))
    // }
    // ```
    //
    // When this function is written with generic vector intrinsics, the
    // optimizer eliminates that pattern and `pmaddwd` is no longer emitted,
    // so the x86 intrinsic is called directly instead.
    unsafe {
        let lhs = a.as_i16x8();
        let rhs = b.as_i16x8();
        transmute(pmaddwd(lhs, rhs))
    }
}
Comment on lines -213 to 227
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could consider using https://doc.rust-lang.org/std/intrinsics/fn.const_eval_select.html so we don't lose all of the const stuff. Up to @sayantn though; I don't have full context on what we'd like to be const fn currently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`#[target_feature]` functions do not implement the `Fn` traits, while `const_eval_select` requires its function arguments to implement `FnOnce`, so this does not seem feasible.


/// Compares packed 16-bit integers in `a` and `b`, and returns the packed
Expand Down Expand Up @@ -3187,6 +3193,8 @@ unsafe extern "C" {
fn lfence();
#[link_name = "llvm.x86.sse2.mfence"]
fn mfence();
#[link_name = "llvm.x86.sse2.pmadd.wd"]
fn pmaddwd(a: i16x8, b: i16x8) -> i32x4;
#[link_name = "llvm.x86.sse2.psad.bw"]
fn psadbw(a: u8x16, b: u8x16) -> u64x2;
#[link_name = "llvm.x86.sse2.psll.w"]
Expand Down Expand Up @@ -3467,7 +3475,7 @@ mod tests {
}

#[simd_test(enable = "sse2")]
const fn test_mm_madd_epi16() {
fn test_mm_madd_epi16() {
let a = _mm_setr_epi16(1, 2, 3, 4, 5, 6, 7, 8);
let b = _mm_setr_epi16(9, 10, 11, 12, 13, 14, 15, 16);
let r = _mm_madd_epi16(a, b);
Expand Down