module Data.Struct.Internal.Label where
import Control.Exception
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Bits
import Data.Struct.Internal
import Data.Word
#ifdef HLINT
#endif
type Key = Word64
midBound :: Key
midBound = unsafeShiftR maxBound 1
key :: Field Label Key
key = field 0
next :: Slot Label Label
next = slot 1
prev :: Slot Label Label
prev = slot 2
newtype Label s = Label (Object s)
instance Eq (Label s) where (==) = eqStruct
instance Struct Label
makeLabel :: PrimMonad m => Key -> Label (PrimState m) -> Label (PrimState m) -> m (Label (PrimState m))
makeLabel a p n = st $ do
this <- alloc 3
setField key this a
set next this n
set prev this p
return this
new :: PrimMonad m => m (Label (PrimState m))
new = makeLabel midBound Nil Nil
delete :: PrimMonad m => Label (PrimState m) -> m ()
delete this = st $ unless (isNil this) $ do
p <- get prev this
n <- get next this
unless (isNil p) $ do
set next p n
set prev this Nil
unless (isNil n) $ do
set prev n p
set next this Nil
insertAfter :: PrimMonad m => Label (PrimState m) -> m (Label (PrimState m))
insertAfter this = st $ do
when (isNil this) $ throw NullPointerException
v0 <- getField key this
n <- get next this
v1 <- if isNil n
then return maxBound
else getField key n
fresh <- makeLabel (v0 + unsafeShiftR (v1 v0) 1) this n
set next this fresh
unless (isNil n) $ set prev n fresh
growRight this v0 n 2
return fresh
where
growRight :: Label s -> Key -> Label s -> Word64 -> ST s ()
growRight !n0 !_ Nil !j = growLeft n0 j
growRight n0 v0 nj j = do
vj <- getField key nj
if vjv0 < j*j
then do
nj' <- get next nj
growRight n0 v0 nj' (j+1)
else do
n1 <- get next n0
balance n1 v0 (delta (vjv0) j) j
growLeft :: Label s -> Word64 -> ST s ()
growLeft !c !j = do
p <- get prev c
if isNil p
then balance c 0 (delta maxBound j) j
else do
vp <- getField key p
p' <- get prev p
let !j' = j+1
if maxBound vp < j'*j'
then growLeft p' j'
else balance c vp (delta (maxBoundvp) j') j'
balance :: Label s -> Key -> Key -> Word64 -> ST s ()
balance !_ !_ !_ 0 = return ()
balance Nil _ _ _ = return ()
balance c v dv j = do
let !v' = v + dv
setField key c v'
n <- get next c
balance n v' dv (j1)
cutAfter :: PrimMonad m => Label (PrimState m) -> m ()
cutAfter this = st $ do
when (isNil this) $ throw NullPointerException
n <- get next this
unless (isNil n) $ do
set next this Nil
set prev n Nil
cutBefore :: PrimMonad m => Label (PrimState m) -> m ()
cutBefore this = st $ do
when (isNil this) $ throw NullPointerException
p <- get prev this
unless (isNil p) $ do
set next p Nil
set prev this Nil
least :: PrimMonad m => Label (PrimState m) -> m (Label (PrimState m))
least xs0
| isNil xs0 = throw NullPointerException
| otherwise = st $ go xs0 where
go :: Label s -> ST s (Label s)
go this = do
p <- get prev this
if isNil p
then return this
else go p
greatest :: PrimMonad m => Label (PrimState m) -> m (Label (PrimState m))
greatest xs0
| isNil xs0 = throw NullPointerException
| otherwise = st $ go xs0 where
go :: Label s -> ST s (Label s)
go this = do
n <- get next this
if isNil n
then return this
else go n
compareM :: PrimMonad m => Label (PrimState m) -> Label (PrimState m) -> m Ordering
compareM i j
| isNil i || isNil j = throw NullPointerException
| otherwise = compare <$> getField key i <*> getField key j
delta :: Key -> Word64 -> Key
delta m j = max 1 $ quot m (j+1)
value :: PrimMonad m => Label (PrimState m) -> m Key
value this = getField key this
keys :: PrimMonad m => Label (PrimState m) -> m [Key]
keys this = st $
if isNil this
then return []
else do
x <- getField key this
n <- get next this
xs <- keys n
return (x:xs)