module Concurrent.Thread
( withCapability
, currentThread
, currentCapability
) where
import Concurrent.Primitive.Class
import Control.Monad
import Control.Monad.Catch
import Control.Monad.Primitive
import GHC.Prim
import GHC.Exts
foreign import prim "pinThreadzh" pinThread# :: State# s -> (# State# s, Int# #)
foreign import prim "unpinThreadzh" unpinThread# :: State# s -> (# State# s #)
foreign import prim "currentThreadzh" currentThread# :: State# s -> (# State# s, Int# #)
foreign import prim "currentCapabilityzh" currentCapability# :: State# s -> (# State# s, Int# #)
pinThread :: MonadPrimIO m => m Bool
pinThread = primitive $ \s -> case pinThread# s of
(# s', b #) -> (# s', isTrue# b #)
unpinThread :: MonadPrimIO m => Bool -> m ()
unpinThread b = unless b $ primitive $ \s -> case unpinThread# s of
(# s' #) -> (# s', () #)
withCapability :: (MonadMask m, MonadPrimIO m) => m a -> m a
withCapability m = bracket pinThread unpinThread (const m)
currentThread :: PrimMonad m => m Int
currentThread = primitive $ \s -> case currentThread# s of
(# s', i #) -> (# s', I# i #)
currentCapability :: PrimMonad m => m Int
currentCapability = primitive $ \s -> case currentCapability# s of
(# s', i #) -> (# s', I# i #)