module Data.Transient.Primitive.Exts
(
sameMutVar
, casMutVar
, sizeOfArray
, sizeOfMutableArray
, casArray
, sizeOfByteArray
, sizeOfMutableByteArray
, casIntArray
, atomicReadIntArray
, atomicWriteIntArray
, fetchAddIntArray
, fetchSubIntArray
, fetchAndIntArray
, fetchNandIntArray
, fetchOrIntArray
, fetchXorIntArray
, prefetchByteArray0
, prefetchByteArray1
, prefetchByteArray2
, prefetchByteArray3
, prefetchMutableByteArray0
, prefetchMutableByteArray1
, prefetchMutableByteArray2
, prefetchMutableByteArray3
, prefetchValue0
, prefetchValue1
, prefetchValue2
, prefetchValue3
) where
import Control.Applicative
import Control.DeepSeq
import Control.Lens
import Control.Monad
import Control.Monad.ST
import Control.Monad.Primitive
import Control.Monad.Zip
import Data.Foldable as Foldable
import Data.Function (on)
import Data.Primitive.Array
import Data.Primitive.ByteArray
import Data.Primitive.MachDeps
import Data.Primitive.MutVar
import Data.Primitive.Types
import GHC.Exts
import Text.Read
casMutVar :: PrimMonad m => MutVar (PrimState m) a -> a -> a -> m (Int, a)
casMutVar (MutVar m) a b = primitive $ \s -> case casMutVar# m a b s of
(# s', i, r #) -> (# s', (I# i, r) #)
sameMutVar :: MutVar s a -> MutVar s a -> Bool
sameMutVar (MutVar a) (MutVar b) = isTrue# (sameMutVar# a b)
sizeOfArray :: Array a -> Int
sizeOfArray (Array m) = I# (sizeofArray# m)
sizeOfMutableArray :: MutableArray s a -> Int
sizeOfMutableArray (MutableArray m) = I# (sizeofMutableArray# m)
casArray :: PrimMonad m => MutableArray (PrimState m) a -> Int -> a -> a -> m (Int, a)
casArray (MutableArray m) (I# i) x y = primitive $ \s -> case casArray# m i x y s of
(# s', i', z #) -> (# s', (I# i', z) #)
sizeOfByteArray :: ByteArray -> Int
sizeOfByteArray (ByteArray m) = I# (sizeofByteArray# m)
sizeOfMutableByteArray :: MutableByteArray s -> Int
sizeOfMutableByteArray (MutableByteArray m) = I# (sizeofMutableByteArray# m)
casIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> Int -> m Int
casIntArray (MutableByteArray m) (I# i) (I# x) (I# y) = primitive $ \s -> case casIntArray# m i x y s of
(# s', i' #) -> (# s', I# i' #)
atomicReadIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> m Int
atomicReadIntArray (MutableByteArray m) (I# i) = primitive $ \s -> case atomicReadIntArray# m i s of
(# s', i' #) -> (# s', I# i' #)
atomicWriteIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> m ()
atomicWriteIntArray (MutableByteArray m) (I# i) (I# j) = primitive_ $ \s -> atomicWriteIntArray# m i j s
fetchAddIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> m Int
fetchAddIntArray (MutableByteArray m) (I# i) (I# j) = primitive $ \s -> case fetchAddIntArray# m i j s of
(# s', k #) -> (# s', I# k #)
fetchSubIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> m Int
fetchSubIntArray (MutableByteArray m) (I# i) (I# j) = primitive $ \s -> case fetchSubIntArray# m i j s of
(# s', k #) -> (# s', I# k #)
fetchAndIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> m Int
fetchAndIntArray (MutableByteArray m) (I# i) (I# j) = primitive $ \s -> case fetchAndIntArray# m i j s of
(# s', k #) -> (# s', I# k #)
fetchNandIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> m Int
fetchNandIntArray (MutableByteArray m) (I# i) (I# j) = primitive $ \s -> case fetchNandIntArray# m i j s of
(# s', k #) -> (# s', I# k #)
fetchOrIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> m Int
fetchOrIntArray (MutableByteArray m) (I# i) (I# j) = primitive $ \s -> case fetchOrIntArray# m i j s of
(# s', k #) -> (# s', I# k #)
fetchXorIntArray :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> Int -> m Int
fetchXorIntArray (MutableByteArray m) (I# i) (I# j) = primitive $ \s -> case fetchXorIntArray# m i j s of
(# s', k #) -> (# s', I# k #)
prefetchByteArray0, prefetchByteArray1, prefetchByteArray2, prefetchByteArray3 :: PrimMonad m => ByteArray -> Int -> m ()
prefetchByteArray0 (ByteArray m) (I# i) = primitive_ $ \s -> prefetchByteArray0# m i s
prefetchByteArray1 (ByteArray m) (I# i) = primitive_ $ \s -> prefetchByteArray1# m i s
prefetchByteArray2 (ByteArray m) (I# i) = primitive_ $ \s -> prefetchByteArray2# m i s
prefetchByteArray3 (ByteArray m) (I# i) = primitive_ $ \s -> prefetchByteArray3# m i s
prefetchMutableByteArray0, prefetchMutableByteArray1, prefetchMutableByteArray2, prefetchMutableByteArray3 :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> m ()
prefetchMutableByteArray0 (MutableByteArray m) (I# i) = primitive_ $ \s -> prefetchMutableByteArray0# m i s
prefetchMutableByteArray1 (MutableByteArray m) (I# i) = primitive_ $ \s -> prefetchMutableByteArray1# m i s
prefetchMutableByteArray2 (MutableByteArray m) (I# i) = primitive_ $ \s -> prefetchMutableByteArray2# m i s
prefetchMutableByteArray3 (MutableByteArray m) (I# i) = primitive_ $ \s -> prefetchMutableByteArray3# m i s
prefetchValue0, prefetchValue1, prefetchValue2, prefetchValue3 :: PrimMonad m => a -> m ()
prefetchValue0 a = primitive_ $ \s -> prefetchValue0# a s
prefetchValue1 a = primitive_ $ \s -> prefetchValue1# a s
prefetchValue2 a = primitive_ $ \s -> prefetchValue2# a s
prefetchValue3 a = primitive_ $ \s -> prefetchValue3# a s
unI# :: Int -> Int#
unI# (I# x) = x
instance Prim (Ptr a) where
sizeOf# _ = unI# sIZEOF_PTR
alignment# _ = unI# aLIGNMENT_PTR
indexByteArray# arr i = Ptr (indexAddrArray# arr i)
readByteArray# arr i s = case readAddrArray# arr i s of
(# s', a #) -> (# s', Ptr a #)
writeByteArray# arr i (Ptr a) s = writeAddrArray# arr i a s
setByteArray# arr i n (Ptr a) s = case unsafeCoerce# (internal (setAddrArray# arr i n a)) s of
(# s', _ #) -> s'
indexOffAddr# addr i = Ptr (indexAddrOffAddr# addr i)
readOffAddr# addr i s = case readAddrOffAddr# addr i s of
(# s', a #) -> (# s', Ptr a #)
writeOffAddr# addr i (Ptr a) s = writeAddrOffAddr# addr i a s
setOffAddr# addr i n (Ptr a) s = case unsafeCoerce# (internal (setAddrOffAddr# addr i n a)) s of
(# s', _ #) -> s'
foreign import ccall unsafe "primitive-memops.h hsprimitive_memset_Ptr"
setAddrArray# :: MutableByteArray# s -> Int# -> Int# -> Addr# -> IO ()
foreign import ccall unsafe "primitive-memops.h hsprimitive_memset_Ptr"
setAddrOffAddr# :: Addr# -> Int# -> Int# -> Addr# -> IO ()
instance Functor Array where
fmap f !i = runST $ do
let n = length i
o <- newArray n undefined
let go !k
| k == n = return ()
| otherwise = do
a <- indexArrayM i k
writeArray o k (f a)
go (k+1)
go 0
unsafeFreezeArray o
instance Foldable Array where
foldr f z arr = go 0 where
n = length arr
go !k
| k == n = z
| otherwise = f (indexArray arr k) (go (k+1))
foldl f z arr = go (length arr 1) where
go !k
| k < 0 = z
| otherwise = f (go (k1)) (indexArray arr k)
foldr' f z arr = go 0 where
n = length arr
go !k
| k == n = z
| r <- indexArray arr k = r `seq` f r (go (k+1))
foldl' f z arr = go (length arr 1) where
go !k
| k < 0 = z
| r <- indexArray arr k = r `seq` f (go (k1)) r
length (Array ary) = I# (sizeofArray# ary)
instance Traversable Array where
traverse f a = fromListN (length a) <$> traverse f (Foldable.toList a)
instance Applicative Array where
pure a = runST $ newArray 1 a >>= unsafeFreezeArray
(m :: Array (a -> b)) <*> (n :: Array a) = runST $ do
o <- newArray (lm * ln) undefined
outer o 0 0
where
lm = length m
ln = length n
outer :: MutableArray s b -> Int -> Int -> ST s (Array b)
outer o !i p
| i < lm = do
f <- indexArrayM m i
inner o i 0 f p
| otherwise = unsafeFreezeArray o
inner :: MutableArray s b -> Int -> Int -> (a -> b) -> Int -> ST s (Array b)
inner o i !j f !p
| j < ln = do
x <- indexArrayM n j
writeArray o p (f x)
inner o i (j + 1) f (p + 1)
| otherwise = outer o (i + 1) p
instance Monad Array where
return = pure
(>>) = (*>)
fail _ = empty
m >>= f = foldMap f m
instance MonadPlus Array where
mzero = empty
mplus = (<|>)
instance MonadZip Array where
mzipWith (f :: a -> b -> c) m n = runST $ do
o <- newArray l undefined
go o 0
where
l = min (length m) (length n)
go :: MutableArray s c -> Int -> ST s (Array c)
go o !i
| i < l = do
a <- indexArrayM m i
b <- indexArrayM n i
writeArray o i (f a b)
go o (i + 1)
| otherwise = unsafeFreezeArray o
munzip m = (fmap fst m, fmap snd m)
instance Alternative Array where
empty = runST $ newArray 0 undefined >>= unsafeFreezeArray
m@(Array pm) <|> n@(Array pn) = runST $ case length m of
lm@(I# ilm) -> case length n of
ln@(I# iln) -> do
o@(MutableArray po) <- newArray (lm + ln) undefined
primitive_ $ \s -> case copyArray# pm 0# po 0# ilm s of
s' -> copyArray# pn 0# po ilm iln s'
unsafeFreezeArray o
instance Monoid (Array a) where
mempty = empty
mappend = (<|>)
instance Eq a => Eq (Array a) where
(==) = (==) `on` Foldable.toList
instance Ord a => Ord (Array a) where
compare = compare `on` Foldable.toList
instance Show a => Show (Array a) where
showsPrec d arr = showParen (d > 10) $
showString "fromList " . showsPrec 11 (Foldable.toList arr)
instance Read a => Read (Array a) where
readPrec = parens $ prec 10 $ do
Ident "fromList" <- lexP
m <- step readPrec
return (fromList m)
instance IsList (Array a) where
type Item (Array a) = a
toList = Foldable.toList
fromListN n xs0 = runST $ do
arr <- newArray n undefined
let go !_ [] = return ()
go k (x:xs) = writeArray arr k x >> go (k+1) xs
go 0 xs0
unsafeFreezeArray arr
fromList xs = fromListN (Prelude.length xs) xs
instance NFData a => NFData (Array a) where
rnf a0 = go a0 (length a0) 0 where
go !a !n !i
| i >= n = ()
| otherwise = rnf (indexArray a i) `seq` go a n (i+1)
type instance Index (Array a) = Int
type instance IxValue (Array a) = a
instance Ixed (Array a) where
ix i f m
| 0 <= i && i < n = go <$> f (indexArray m i)
| otherwise = pure m
where
n = sizeOfArray m
go a = runST $ do
o <- newArray n undefined
copyArray o 0 m 0 n
writeArray o i a
unsafeFreezeArray o
instance AsEmpty (Array a) where
_Empty = nearly empty null
instance Cons (Array a) (Array b) a b where
_Cons = prism hither yon where
hither (a,m) | n <- sizeOfArray m = runST $ do
o <- newArray (n + 1) a
copyArray o 1 m 0 n
unsafeFreezeArray o
yon m
| n <- sizeOfArray m
, n > 0 = Right
( indexArray m 0
, runST $ do
o <- newArray (n 1) undefined
copyArray o 0 m 1 (n 1)
unsafeFreezeArray o
)
| otherwise = Left empty
instance Snoc (Array a) (Array b) a b where
_Snoc = prism hither yon where
hither (m,a) | n <- sizeOfArray m = runST $ do
o <- newArray (n + 1) a
copyArray o 0 m 0 n
unsafeFreezeArray o
yon m
| n <- sizeOfArray m
, n > 0 = Right
( runST $ do
o <- newArray (n 1) undefined
copyArray o 0 m 0 (n 1)
unsafeFreezeArray o
, indexArray m 0
)
| otherwise = Left empty
instance Eq (MutableArray s a) where
(==) = sameMutableArray
instance Monoid ByteArray where
mempty = runST $ newByteArray 0 >>= unsafeFreezeByteArray
mappend m n = runST $ do
o <- newByteArray (lm + ln)
copyByteArray o 0 m 0 lm
copyByteArray o lm n 0 ln
unsafeFreezeByteArray o
where lm = sizeOfByteArray m
ln = sizeOfByteArray n
instance AsEmpty ByteArray where
_Empty = nearly mempty $ \s -> sizeOfByteArray s == 0
instance Eq (MutableByteArray s) where
(==) = sameMutableByteArray