Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement selectBits & excludeBits in C #82

Merged
merged 3 commits into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions cbits/bitvec_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,98 @@ HsInt _hs_bitvec_nth_bit_index(const HsWord *src, HsInt len, HsBool bit, HsInt n
}
return -1;
}


#ifdef __x86_64__
// BMI2 implementation of bit selection for 64-bit words.
//
// For every word index i in [0, len) the bits of src[i] at the positions set
// in (mask[i] ^ bit_mask) are extracted with PEXT and appended, densely
// packed, to the bit stream in `dest`.
//
//   dest    - output words; on iteration i the highest index written is at
//             most i, so `len` words of storage are sufficient.
//   src     - input words to take bits from.
//   mask    - selector words, same length as `src`.
//   len     - number of words in `src` and `mask`; values <= 0 do nothing.
//   exclude - when non-zero the mask is inverted, i.e. bits whose mask bit is
//             clear are kept instead.
//
// Returns the total number of bits appended to `dest`.
__attribute__((target("popcnt,bmi2")))
static HsInt select_bits_pext(uint64_t *dest, const uint64_t *src, const uint64_t *mask, HsInt len, HsBool exclude) {
    // XOR with all ones flips every mask bit; XOR with zero leaves it as is.
    const uint64_t bit_mask = exclude ? ~(uint64_t)0 : 0;
    HsInt off = 0; // offset in bits into `dest`
    // The index is signed on purpose: with an unsigned index a negative `len`
    // would be converted to a huge bound and the loop would run off the buffers.
    for (HsInt i = 0; i < len; i++) {
        const uint64_t m = mask[i] ^ bit_mask;
        const uint64_t y = _pext_u64(src[i], m); // selected bits, packed into the low end
        const HsInt off_words = off >> 6;
        const HsInt off_bits = off & 0x3f;
        if (off_bits == 0) {
            // Word-aligned: start a fresh destination word.
            dest[off_words] = y;
        } else {
            // Straddling: top up the current word and spill the rest into the next.
            dest[off_words] |= y << off_bits;
            dest[off_words + 1] = y >> (64 - off_bits);
        }
        off += (HsInt)_mm_popcnt_u64(m);
    }
    return off;
}
#endif

// Append to `dest` the bits of `src` chosen by `mask`, densely packed.
//
// For every word index i in [0, len) the bits of src[i] at the positions set
// in mask[i] (or clear in mask[i] when `exclude` is non-zero) are appended to
// the bit stream in `dest`, lowest bit first.
//
//   dest    - output words; on iteration i the highest index written is at
//             most i, so `len` words of storage are sufficient.
//   src     - input words to take bits from.
//   mask    - selector words, same length as `src`.
//   len     - number of words in `src` and `mask`; values <= 0 do nothing.
//   exclude - when non-zero the mask is inverted before use.
//
// Returns the total number of bits appended to `dest`.
//
// On x86_64 the work is handed to select_bits_pext when the running CPU
// reports both POPCNT and BMI2; otherwise the portable loop below is used.
// The portable loop supports 32-bit and 64-bit HsWord.
HsInt _hs_bitvec_select_bits(HsWord *dest, const HsWord *src, const HsWord *mask, HsInt len, HsBool exclude) {
#ifdef __x86_64__
    if (__builtin_cpu_supports("popcnt") && __builtin_cpu_supports("bmi2")) {
        return select_bits_pext(dest, src, mask, len, exclude);
    }
#endif
    const HsInt word_bits = (HsInt)(sizeof(HsWord) * 8);
    const HsInt word_shift = sizeof(HsWord) == 8 ? 6 : 5; // log2(word_bits)
    const HsWord all_ones = (HsWord)-1;
    // XOR with all ones flips every mask bit; XOR with zero leaves it as is.
    const HsWord bit_mask = exclude ? all_ones : 0;

    HsInt off = 0; // offset in bits into `dest`
    // The index is signed on purpose: with an unsigned index a negative `len`
    // would be converted to a huge bound and the loop would run off the buffers.
    for (HsInt i = 0; i < len; i++) {
        const HsWord x = src[i];
        HsWord m = mask[i] ^ bit_mask;

        // Software pext: y receives the selected bits of x packed into its low
        // end, count the number of bits selected.
        HsWord y = 0;
        HsInt count = 0;
        if (m == all_ones) {
            // Everything selected. Handled apart so that the shift below never
            // reaches the full word width.
            y = x;
            count = word_bits;
        } else {
            // Visit the set bits of m from lowest to highest; `m & -m` isolates
            // the lowest one and `m &= m - 1` clears it.
            for (; m != 0; m &= m - 1) {
                if (x & m & -m) {
                    y |= (HsWord)1 << count;
                }
                count++;
            }
        }

        const HsInt off_words = off >> word_shift;
        const HsInt off_bits = off & (word_bits - 1);
        if (off_bits == 0) {
            // Word-aligned: start a fresh destination word.
            dest[off_words] = y;
        } else {
            // Straddling: top up the current word and spill the rest into the next.
            dest[off_words] |= y << off_bits;
            dest[off_words + 1] = y >> (word_bits - off_bits);
        }
        off += count;
    }
    return off;
}
16 changes: 16 additions & 0 deletions src/Data/Bit/Immutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,14 @@ invertBits xs = runST $ do
--
-- @since 0.1
selectBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
#ifdef UseSIMD
-- Fast path that delegates to the C routine behind selectBitsC.
-- It applies only when both arguments match a literal 0 in the first BitVec
-- field (presumably the bit offset; confirm against the BitVec definition)
-- and the common length passes the modWordSize guard (presumably meaning a
-- whole number of machine words; confirm). Any other input falls through to
-- the generic clause that follows the endif.
selectBits (BitVec 0 iLen iArr) (BitVec 0 xLen xArr) | modWordSize len == 0 = runST $ do
-- len is counted in bits, so shifting right by 3 gives the size in bytes.
marr <- newByteArray (len `shiftR` 3)
-- xArr (the data) is passed as the source and iArr (the selector) as the
-- mask. False is the exclude flag, i.e. keep the bits whose mask bit is set.
-- n is the number of bits the C code reports having written.
n <- selectBitsC marr xArr iArr (divWordSize len) False
BitVec 0 n <$> unsafeFreezeByteArray marr
where
-- Only the overlapping prefix of the two vectors is processed.
len = min iLen xLen
#endif
selectBits is xs = runST $ do
xs1 <- U.thaw xs
n <- selectBitsInPlace is xs1
Expand All @@ -502,6 +510,14 @@ selectBits is xs = runST $ do
--
-- @since 0.1
excludeBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit
#ifdef UseSIMD
-- Fast path that delegates to the C routine behind selectBitsC.
-- It applies only when both arguments match a literal 0 in the first BitVec
-- field (presumably the bit offset; confirm against the BitVec definition)
-- and the common length passes the modWordSize guard (presumably meaning a
-- whole number of machine words; confirm). Any other input falls through to
-- the generic clause that follows the endif.
excludeBits (BitVec 0 iLen iArr) (BitVec 0 xLen xArr) | modWordSize len == 0 = runST $ do
-- len is counted in bits, so shifting right by 3 gives the size in bytes.
marr <- newByteArray (len `shiftR` 3)
-- xArr (the data) is passed as the source and iArr (the selector) as the
-- mask. True is the exclude flag, i.e. keep the bits whose mask bit is clear.
-- n is the number of bits the C code reports having written.
n <- selectBitsC marr xArr iArr (divWordSize len) True
BitVec 0 n <$> unsafeFreezeByteArray marr
where
-- Only the overlapping prefix of the two vectors is processed.
len = min iLen xLen
#endif
excludeBits is xs = runST $ do
xs1 <- U.thaw xs
n <- excludeBitsInPlace is xs1
Expand Down
9 changes: 9 additions & 0 deletions src/Data/Bit/SIMD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module Data.Bit.SIMD
, reverseBitsC
, bitIndexC
, nthBitIndexC
, selectBitsC
) where

import Control.Monad.ST
Expand Down Expand Up @@ -145,3 +146,11 @@ nthBitIndexC :: ByteArray -> Int -> Bool -> Int -> Int
nthBitIndexC (ByteArray arg#) (I# len#) bit (I# n#) =
I# (nth_bit_index arg# len# bit n#)
{-# INLINE nthBitIndexC #-}

-- | Raw binding to the C function @_hs_bitvec_select_bits@.
foreign import ccall unsafe "_hs_bitvec_select_bits"
  select_bits_c :: MutableByteArray# s -> ByteArray# -> ByteArray# -> Int# -> Bool -> IO Int

-- | Pack the selected bits of a source array into a destination array.
--
-- The arguments are, in order: the destination buffer, the source words,
-- the mask words, the number of words to process, and the exclude flag
-- (@True@ keeps the bits whose mask bit is clear, @False@ keeps those whose
-- mask bit is set). The result is the number of bits written to the
-- destination, as reported by the C routine.
selectBitsC :: MutableByteArray s -> ByteArray -> ByteArray -> Int -> Bool -> ST s Int
selectBitsC (MutableByteArray dest#) (ByteArray src#) (ByteArray mask#) (I# words#) exclude =
  unsafeIOToST $ select_bits_c dest# src# mask# words# exclude
{-# INLINE selectBitsC #-}
21 changes: 18 additions & 3 deletions test/Tests/SetOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ setOpTests = testGroup "Set operations"
, testProperty "invertInPlace middle" prop_invertInPlace_middle
, testProperty "invertInPlaceLong middle" prop_invertInPlaceLong_middle

, adjustOption (\n -> max 500 n :: QuickCheckTests) $ mkGroup "reverseBits" prop_reverseBits
, mkGroup "reverseBits" prop_reverseBits

, testProperty "reverseInPlace" prop_reverseInPlace
, testProperty "reverseInPlaceWords" prop_reverseInPlaceWords
, testProperty "reverseInPlace middle" prop_reverseInPlace_middle
, testProperty "reverseInPlaceLong middle" prop_reverseInPlaceLong_middle

, testProperty "selectBits" prop_selectBits_def
, testProperty "excludeBits" prop_excludeBits_def
, mkGroup2 "selectBits" prop_selectBits_def
, mkGroup2 "excludeBits" prop_excludeBits_def

, mkGroup "countBits" prop_countBits_def
]
Expand All @@ -69,6 +69,21 @@ mkGroup name prop = testGroup name
propMiddleLong (NonNegative x) (NonNegative y) (NonNegative z) =
propMiddle (NonNegative $ x * 31) (NonNegative $ y * 37) (NonNegative $ z * 29)

-- | Build a test group for a property over two bit vectors.
--
-- The group runs the property on plain generated vectors ("simple"), on
-- vectors wrapped in 'Large' ("simple_long"), and on slices cut out of the
-- middle of longer generated vectors ("middle" and "middle_long", the latter
-- with every size scaled up by a fixed factor).
mkGroup2 :: String -> (U.Vector Bit -> U.Vector Bit -> Property) -> TestTree
mkGroup2 name prop = testGroup name
  [ testProperty "simple" prop
  , testProperty "simple_long" propLong
  , testProperty "middle" propMiddle
  , testProperty "middle_long" propMiddleLong
  ]
  where
    propLong (Large xs) (Large ys) = prop xs ys

    -- Deterministic, irregular bit pattern indexed by position.
    bitAt ix = odd (truncate (exp (abs (sin position) * 10)) :: Integer)
      where
        position = fromIntegral ix :: Double

    -- A slice of the given size starting at the given index, taken from a
    -- generated vector that extends beyond the slice by the given slack.
    window start size slack =
      U.slice start size (U.generate (start + size + slack) (Bit . bitAt))

    propMiddle
      (NonNegative from1) (NonNegative len1) (NonNegative excess1)
      (NonNegative from2) (NonNegative len2) (NonNegative excess2) =
        prop (window from1 len1 excess1) (window from2 len2 excess2)

    -- Same as propMiddle but with each size multiplied by a fixed factor.
    propMiddleLong
      (NonNegative x1) (NonNegative y1) (NonNegative z1)
      (NonNegative x2) (NonNegative y2) (NonNegative z2) =
        propMiddle
          (NonNegative (x1 * 31)) (NonNegative (y1 * 37)) (NonNegative (z1 * 29))
          (NonNegative (x2 * 31)) (NonNegative (y2 * 37)) (NonNegative (z2 * 29))

prop_generalize1 :: Fun Bit Bit -> Bit -> Property
prop_generalize1 fun x =
applyFun fun x === generalize1 (applyFun fun) x
Expand Down
Loading