Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SIMD intrinsics for reverseBits #71

Merged
merged 4 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions cbits/bitvec_simd.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <inttypes.h>
#include <stddef.h>

#ifdef __x86_64__
#include <immintrin.h>
#endif

#include "HsFFI.h"

HsInt _hs_bitvec_popcount(const uint32_t *src, HsInt len) {
Expand Down Expand Up @@ -80,3 +84,104 @@ void _hs_bitvec_xnor(uint8_t *dest, const uint8_t *src1, const uint8_t *src2, Hs
dest[i] = ~(src1[i] ^ src2[i]);
}
}

#ifdef __x86_64__
/*
 * SSE2 implementation of whole-buffer bit reversal.
 *
 * Treats `src` as `len` 32-bit words and writes to `dest` the same bits in
 * fully reversed order: the bits inside every word are mirrored and the word
 * order is mirrored as well, so that dest[len - 1 - k] == bitreverse(src[k]).
 *
 * dest - output buffer with room for `len` words
 * src  - input buffer holding `len` words
 * len  - number of 32-bit words; a non-positive value is a no-op
 *
 * In-place use (dest == src) is not supported: each output slot is written
 * before the mirrored input slot has been read.
 */
static void reverse_bits_sse(uint32_t *dest, const uint32_t *src, HsInt len) {
    /*
     * One mask per swap width. Both halves of every swap are expressed with
     * the same mask (mask-then-shift-left, shift-right-then-mask), so every
     * constant fits in a signed 32-bit int as _mm_set1_epi32 expects.
     */
    const __m128i mask1 = _mm_set1_epi32(0x55555555);
    const __m128i mask2 = _mm_set1_epi32(0x33333333);
    const __m128i mask4 = _mm_set1_epi32(0x0f0f0f0f);
    const __m128i mask8 = _mm_set1_epi32(0x00ff00ff);

    /*
     * The index is signed like `len`: with an unsigned index a negative
     * `len` would be converted to a huge bound and the loops would run wild.
     */
    const HsInt vec_len = len & ~(HsInt) 0x3;
    HsInt i = 0;

    /* Vector part: four words per iteration. */
    for (; i < vec_len; i += 4) {
        __m128i x = _mm_loadu_si128((const __m128i *) (src + i));

        /* Reverse the bits of each word: swap neighbours of width 1, 2, 4, 8, 16. */
        x = _mm_or_si128(_mm_slli_epi32(_mm_and_si128(x, mask1), 1), _mm_and_si128(_mm_srli_epi32(x, 1), mask1));
        x = _mm_or_si128(_mm_slli_epi32(_mm_and_si128(x, mask2), 2), _mm_and_si128(_mm_srli_epi32(x, 2), mask2));
        x = _mm_or_si128(_mm_slli_epi32(_mm_and_si128(x, mask4), 4), _mm_and_si128(_mm_srli_epi32(x, 4), mask4));
        x = _mm_or_si128(_mm_slli_epi32(_mm_and_si128(x, mask8), 8), _mm_and_si128(_mm_srli_epi32(x, 8), mask8));
        /* Swapping the two 16-bit halves of a 32-bit lane needs no mask. */
        x = _mm_or_si128(_mm_slli_epi32(x, 16), _mm_srli_epi32(x, 16));

        /* Reverse the order of the four words (selector 0x1b = lanes 3,2,1,0). */
        x = _mm_shuffle_epi32(x, 0x1b);

        /* The i-th group from the front lands i groups from the back. */
        _mm_storeu_si128((__m128i *) (dest + len - 4 - i), x);
    }

    /* Scalar tail: the remaining len % 4 words, same algorithm. */
    for (; i < len; i++) {
        uint32_t x = src[i];
        x = ((x & 0x55555555) << 1) | ((x >> 1) & 0x55555555);
        x = ((x & 0x33333333) << 2) | ((x >> 2) & 0x33333333);
        x = ((x & 0x0f0f0f0f) << 4) | ((x >> 4) & 0x0f0f0f0f);
        x = ((x & 0x00ff00ff) << 8) | ((x >> 8) & 0x00ff00ff);
        x = (x << 16) | (x >> 16);
        dest[len - 1 - i] = x;
    }
}

/*
 * AVX2 implementation of whole-buffer bit reversal.
 *
 * Treats `src` as `len` 32-bit words and writes to `dest` the same bits in
 * fully reversed order: the bits inside every word are mirrored and the word
 * order is mirrored as well, so that dest[len - 1 - k] == bitreverse(src[k]).
 *
 * The target attribute compiles just this function with AVX2 enabled; the
 * caller is responsible for checking at run time that the CPU supports it.
 *
 * dest - output buffer with room for `len` words
 * src  - input buffer holding `len` words
 * len  - number of 32-bit words; a non-positive value is a no-op
 *
 * In-place use (dest == src) is not supported: each output slot is written
 * before the mirrored input slot has been read.
 */
__attribute__((target("avx2")))
static void reverse_bits_avx(uint32_t *dest, const uint32_t *src, HsInt len) {
    /*
     * One mask per swap width. Both halves of every swap are expressed with
     * the same mask (mask-then-shift-left, shift-right-then-mask), so every
     * constant fits in a signed 32-bit int as _mm256_set1_epi32 expects.
     */
    const __m256i mask1 = _mm256_set1_epi32(0x55555555);
    const __m256i mask2 = _mm256_set1_epi32(0x33333333);
    const __m256i mask4 = _mm256_set1_epi32(0x0f0f0f0f);
    const __m256i mask8 = _mm256_set1_epi32(0x00ff00ff);

    /* Lane permutation that mirrors the eight words; built once, outside the loop. */
    const __m256i reverse_idx = _mm256_setr_epi32(7, 6, 5, 4, 3, 2, 1, 0);

    /*
     * The index is signed like `len`: with an unsigned index a negative
     * `len` would be converted to a huge bound and the loops would run wild.
     */
    const HsInt vec_len = len & ~(HsInt) 0x7;
    HsInt i = 0;

    /* Vector part: eight words per iteration. */
    for (; i < vec_len; i += 8) {
        __m256i x = _mm256_loadu_si256((const __m256i *) (src + i));

        /* Reverse the bits of each word: swap neighbours of width 1, 2, 4, 8, 16. */
        x = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(x, mask1), 1), _mm256_and_si256(_mm256_srli_epi32(x, 1), mask1));
        x = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(x, mask2), 2), _mm256_and_si256(_mm256_srli_epi32(x, 2), mask2));
        x = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(x, mask4), 4), _mm256_and_si256(_mm256_srli_epi32(x, 4), mask4));
        x = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(x, mask8), 8), _mm256_and_si256(_mm256_srli_epi32(x, 8), mask8));
        /* Swapping the two 16-bit halves of a 32-bit lane needs no mask. */
        x = _mm256_or_si256(_mm256_slli_epi32(x, 16), _mm256_srli_epi32(x, 16));

        /* Reverse the order of the eight words. */
        x = _mm256_permutevar8x32_epi32(x, reverse_idx);

        /* The i-th group from the front lands i groups from the back. */
        _mm256_storeu_si256((__m256i *) (dest + len - 8 - i), x);
    }

    /* Scalar tail: the remaining len % 8 words, same algorithm. */
    for (; i < len; i++) {
        uint32_t x = src[i];
        x = ((x & 0x55555555) << 1) | ((x >> 1) & 0x55555555);
        x = ((x & 0x33333333) << 2) | ((x >> 2) & 0x33333333);
        x = ((x & 0x0f0f0f0f) << 4) | ((x >> 4) & 0x0f0f0f0f);
        x = ((x & 0x00ff00ff) << 8) | ((x >> 8) & 0x00ff00ff);
        x = (x << 16) | (x >> 16);
        dest[len - 1 - i] = x;
    }
}
#endif

/*
 * Reverse all bits of a buffer of `len` 32-bit words.
 *
 * After the call dest[len - 1 - k] holds src[k] with its 32 bits mirrored,
 * i.e. the buffer read as one long bit string is reversed end to end.
 *
 * On x86-64 the work is dispatched at run time: the AVX2 routine is used
 * when the CPU reports support for it, the SSE routine otherwise. On every
 * other architecture a portable scalar loop is used.
 *
 * dest - output buffer with room for `len` words
 * src  - input buffer holding `len` words
 * len  - number of 32-bit words
 */
void _hs_bitvec_reverse_bits(uint32_t *dest, const uint32_t *src, HsInt len) {
#ifdef __x86_64__
    if (__builtin_cpu_supports("avx2")) {
        reverse_bits_avx(dest, src, len);
    } else {
        reverse_bits_sse(dest, src, len);
    }
#else
    /*
     * The index has the same signed type as `len`, so the comparison is not
     * performed in unsigned arithmetic and a non-positive `len` is a no-op.
     */
    for (HsInt i = 0; i < len; i++) {
        uint32_t x = src[i];
        /* Swap neighbouring groups of 1, 2, 4, 8 and 16 bits in turn. */
        x = ((x & 0x55555555) << 1) | ((x & 0xaaaaaaaa) >> 1);
        x = ((x & 0x33333333) << 2) | ((x & 0xcccccccc) >> 2);
        x = ((x & 0x0f0f0f0f) << 4) | ((x & 0xf0f0f0f0) >> 4);
        x = ((x & 0x00ff00ff) << 8) | ((x & 0xff00ff00) >> 8);
        x = ((x & 0x0000ffff) << 16) | ((x & 0xffff0000) >> 16);
        /* Mirror the word position as well. */
        dest[len - 1 - i] = x;
    }
#endif
}
7 changes: 7 additions & 0 deletions src/Data/Bit/Immutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,13 @@ excludeBits is xs = runST $ do
--
-- @since 1.0.1.0
reverseBits :: U.Vector Bit -> U.Vector Bit
#ifdef UseSIMD
reverseBits (BitVec 0 len arr) | modWordSize len == 0 = runST $ do
let n = len `shiftR` 5 -- length in 32 bit words
marr <- newByteArray (n `shiftL` 2)
konsumlamm marked this conversation as resolved.
Show resolved Hide resolved
reverseBitsC marr arr n
BitVec 0 len <$> unsafeFreezeByteArray marr
#endif
reverseBits xs = runST $ do
let n = U.length xs
ys <- MU.new n
Expand Down
10 changes: 10 additions & 0 deletions src/Data/Bit/SIMD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module Data.Bit.SIMD
, ompNand
, ompNior
, ompXnor
, reverseBitsC
) where

import Control.Monad.ST
Expand Down Expand Up @@ -118,3 +119,12 @@ ompXnor :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> ST s ()
ompXnor (MutableByteArray res#) (ByteArray arg1#) (ByteArray arg2#) (I# len#) =
unsafeIOToST (omp_xnor res# arg1# arg2# len#)
{-# INLINE ompXnor #-}

-- | Raw FFI binding to the C routine @_hs_bitvec_reverse_bits@.
--
-- Arguments, in order: the destination buffer, the source buffer, and the
-- length, which the C side counts in 32-bit words.
--
-- NOTE(review): the call is declared @unsafe@; presumably this is what makes
-- it acceptable to hand unpinned 'ByteArray#' values to C — confirm against
-- the GHC FFI documentation before changing the calling convention.
foreign import ccall unsafe "_hs_bitvec_reverse_bits"
reverse_bits :: MutableByteArray# s -> ByteArray# -> Int# -> IO ()

-- | Reverse the bits of a whole buffer by calling the C routine
-- @_hs_bitvec_reverse_bits@.
--
-- The first argument receives the result, the second is the input, and the
-- third is the length, measured in 32-bit words (not in bits or bytes).
reverseBitsC :: MutableByteArray s -> ByteArray -> Int -> ST s ()
reverseBitsC (MutableByteArray dst#) (ByteArray src#) (I# nWords#) =
  unsafeIOToST $ reverse_bits dst# src# nWords#
{-# INLINE reverseBitsC #-}
2 changes: 1 addition & 1 deletion test/Tests/SetOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ setOpTests = testGroup "Set operations"
, testProperty "invertInPlace middle" prop_invertInPlace_middle
, testProperty "invertInPlaceLong middle" prop_invertInPlaceLong_middle

, mkGroup "reverseBits" prop_reverseBits
, localOption (QuickCheckTests 500) $ mkGroup "reverseBits" prop_reverseBits
konsumlamm marked this conversation as resolved.
Show resolved Hide resolved

, testProperty "reverseInPlace" prop_reverseInPlace
, testProperty "reverseInPlaceWords" prop_reverseInPlaceWords
Expand Down