module Concurrent.Struct.Heap
(
Heap(..)
, newHeap
, insert
, extractMin
, meld
, delete
, setKey
, Node(..)
, insertNode
, extractMinNode
, meldNode
, deleteNode
, setKeyNode
, key
, value
) where
import Concurrent.Primitive.Class
import Concurrent.Struct.Ref
import Control.Monad
import Data.Struct.Internal
import Data.Struct.TH
import GHC.ST
-- A heap node. The three strict, struct-typed fields (parent, child, sibling)
-- are the link pointers that the rest of this module reads and writes with
-- 'get' / 'set'; 'key' and 'value' are accessed with 'getField' / 'setField'.
-- The field order below is also the argument order of the 'newNode'
-- allocator used by 'insertNode' (presumably generated by 'makeStruct').
makeStruct [d|
  data Node k v s = Node
    { parent  :: !(Node k v s)
    , child   :: !(Node k v s)
    , sibling :: !(Node k v s)
    , key     :: k
    , value   :: v
    }
  |]
-- | A mutable heap handle: a 'Ref' holding the current root 'Node'.
-- The reference holds 'Nil' while the heap is empty (see 'newHeap').
newtype Heap k v s = Heap (Ref (Node k v) s)
-- | A heap represented directly by its root node, as used by the @*Node@
-- functions. 'Nil' stands for the empty heap.
type TransientHeap = Node
-- | Allocate a fresh, empty heap (its root reference is initialised to 'Nil').
newHeap :: MonadPrim s m => m (Heap k v s)
newHeap = primST $ do
  root <- newRef Nil
  return (Heap root)
-- | Run a root-transforming action against the heap's current root, store
-- the new root it produces back into the heap, and return its extra result.
modify
  :: MonadPrim s m
  => (TransientHeap k v s -> ST s (a, TransientHeap k v s))
  -> Heap k v s
  -> m a
modify f (Heap ref) =
  primST $
    readRef ref >>= f >>= \(result, root') ->
      writeRef ref root' >> return result
-- | Like 'modify', for actions that only produce a new root: read the current
-- root, transform it, and write the result back.
modify_
  :: MonadPrim s m
  => (TransientHeap k v s -> ST s (TransientHeap k v s))
  -> Heap k v s
  -> m ()
modify_ f (Heap ref) = primST (readRef ref >>= f >>= writeRef ref)
-- | Insert a key/value pair into the heap. Returns the freshly allocated
-- node, which can later be handed to 'delete' or 'setKey'.
insert :: (MonadPrim s m, Ord k) => k -> v -> Heap k v s -> m (Node k v s)
insert k v = modify (insertNode k v)
-- | Insert a key/value pair into the heap rooted at the given node.
-- Returns the newly allocated node together with the new root. For an empty
-- heap ('Nil') the new node is itself the root; otherwise it is melded with
-- the existing root.
insertNode :: Ord k => k -> v -> TransientHeap k v s -> ST s (Node k v s, TransientHeap k v s)
insertNode k v r = do
  fresh <- newNode Nil Nil Nil k v
  root' <- if isNil r then return fresh else meldNode r fresh
  return (fresh, root')
-- | Remove the root (minimum) node from the heap and return it.
-- Returns 'Nil' when the heap is empty.
extractMin :: (MonadPrim s m, Ord k) => Heap k v s -> m (Node k v s)
extractMin heap = modify extractMinNode heap
-- | Remove the root of the heap rooted at @r@.
--
-- Returns the removed node paired with the new root (the combination of the
-- removed node's children, or 'Nil' if it had none). For an empty heap both
-- components are 'Nil'.
--
-- The removed node and the new root are fully unlinked from each other:
-- 'combineSiblings' leaves the new root's @parent@ pointing at the removed
-- node and the removed node's @child@ pointing into the heap, which would
-- keep the removed node reachable from the heap and would mislead any later
-- 'cutParent' on the root.
extractMinNode :: Ord k => TransientHeap k v s -> ST s (Node k v s, TransientHeap k v s)
extractMinNode r
  | isNil r = return (r, r)
  | otherwise = do
      r' <- combineSiblings r
      -- Sever the stale links in both directions.
      set child r Nil
      unless (isNil r') $ set parent r' Nil
      return (r, r')
-- | Move every element of the second heap into the first, leaving the second
-- heap empty.
--
-- 'meldNode' reads the key of both roots, so it must not be given 'Nil';
-- empty heaps are therefore handled here:
--
-- * if the second heap is empty, nothing changes;
-- * if the first heap is empty, it simply takes over the second heap's root.
--
-- Melding a heap with itself (both handles share the same root) is a no-op
-- rather than linking the root underneath itself and then emptying the heap.
meld :: (MonadPrim s m, Ord k) => Heap k v s -> Heap k v s -> m ()
meld (Heap p) (Heap q) = primST $ do
  a <- readRef p
  b <- readRef q
  let sameRoot = not (isNil a) && a == b
  unless (isNil b || sameRoot) $ do
    c <- if isNil a then return b else meldNode a b
    writeRef p c
    writeRef q Nil
-- | Link two heap roots into one tree and return the resulting root.
--
-- The root with the strictly smaller key wins; on ties the second argument
-- wins. The losing root becomes the first child of the winner.
-- Both arguments have their key read, so neither may be 'Nil'.
meldNode :: Ord k => TransientHeap k v s -> TransientHeap k v s -> ST s (TransientHeap k v s)
meldNode p q = do
  kp <- getField key p
  kq <- getField key q
  if kp < kq
    then link p q
    else link q p
  where
    -- Hang @loser@ under @winner@ and hand back the winner.
    link winner loser = attachChild winner loser >> return winner
-- | Push @c@ onto the front of @p@'s child list: @c@ becomes @p@'s first
-- child, its sibling becomes @p@'s previous first child, and its parent
-- becomes @p@.
attachChild :: Node k v s -> Node k v s -> ST s ()
attachChild p c = do
  formerFirst <- get child p
  set parent c p
  set sibling c formerFirst
  set child p c
-- | Combine all children of @p@ into a single tree and return its root
-- ('Nil' when @p@ has no children). Runs 'pairChildren' followed by
-- 'linkChildren'.
combineSiblings :: Ord k => Node k v s -> ST s (Node k v s)
combineSiblings p = pairChildren p >> linkChildren p
-- | Pairing pass: walk @p@'s child list, meld the children two at a time,
-- and re-attach each resulting tree to @p@. A trailing unpaired child is
-- re-attached as is. Does nothing when @p@ has no children.
pairChildren :: Ord k => Node k v s -> ST s ()
pairChildren p = do
  first <- get child p
  if isNil first
    then return ()
    else do
      -- Empty the child list; it is rebuilt by 'attachChild' below.
      set child p Nil
      pairFrom first
  where
    pairFrom x = do
      y <- get sibling x
      if isNil y
        then attachChild p x -- odd one out, no partner to meld with
        else do
          -- Detach both nodes from the old list before melding them.
          set sibling x Nil
          rest <- get sibling y
          set sibling y Nil
          paired <- meldNode x y
          attachChild p paired
          if isNil rest then return () else pairFrom rest
-- | Linking pass: meld every child of @p@ into one tree, walking the child
-- list front to back with a running accumulator, and return the resulting
-- root. Returns 'Nil' when @p@ has no children.
linkChildren :: Ord k => Node k v s -> ST s (Node k v s)
linkChildren p = do
  first <- get child p
  if isNil first
    then return Nil
    else do
      rest <- get sibling first
      set sibling first Nil
      fold first rest
  where
    -- @acc@ is the tree built so far, @cur@ the next list element to absorb.
    fold acc cur =
      if isNil cur
        then return acc
        else do
          next <- get sibling cur
          set sibling cur Nil
          acc' <- meldNode acc cur
          fold acc' next
-- | Remove the given node from the heap and store the resulting root back
-- into the heap. See 'deleteNode'.
delete :: (MonadPrim s m, Ord k) => Heap k v s -> Node k v s -> m ()
delete h n = modify_ (\root -> deleteNode root n) h
-- | Remove node @n@ from the heap rooted at @r@ and return the new root.
-- @n@ is assumed to belong to that heap.
--
-- * If @n@ is the root, this is just 'extractMinNode'.
-- * Otherwise @n@ is cut out of its parent's child list, its children are
--   combined into a single tree, and that tree is melded back with the
--   heap's root @r@.
--
-- The orphaned subtree must be melded with the root @r@, not with @n@'s
-- former parent: the result of 'meldNode' is returned as the new heap root,
-- so melding with an interior node would return a non-root (or even promote
-- the subtree above its former grandparent on equal keys).
deleteNode :: Ord k => TransientHeap k v s -> Node k v s -> ST s (TransientHeap k v s)
deleteNode r n
  | r == n = snd <$> extractMinNode n
  | otherwise = do
      cutParent n
      orphans <- snd <$> extractMinNode n
      if isNil orphans
        then return r -- n was a leaf; the rest of the heap is untouched
        else meldNode r orphans
-- | Change the key of a node in the heap and store the resulting root back
-- into the heap. See 'setKeyNode'.
setKey :: (MonadPrim s m, Ord k) => Heap k v s -> Node k v s -> k -> m ()
setKey heap n k = modify_ (\root -> setKeyNode root n k) heap
-- | Set the key of node @n@ (assumed to belong to the heap rooted at @h@) to
-- @k@ and return the new root.
--
-- * Key not increased (@k <= old key@): the root needs no work; any other
--   node is cut from its parent and melded with the root, taking its whole
--   subtree along (its children still satisfy heap order).
-- * Key increased: @n@'s children may now be smaller than @n@, so @n@ is
--   removed from the heap, its children are combined and melded back, and
--   @n@ is re-inserted as a childless node. This also covers the root, which
--   must never be passed to 'cutParent' or melded with itself.
setKeyNode :: Ord k => TransientHeap k v s -> Node k v s -> k -> ST s (TransientHeap k v s)
setKeyNode h n k = do
  ok <- getField key n
  setField key n k
  if k <= ok
    then
      if n == h
        then return h
        else do
          cutParent n
          meldNode h n
    else do
      rest <-
        if n == h
          then do
            sub <- orphanChildren
            set parent n Nil
            return sub
          else do
            cutParent n
            sub <- orphanChildren
            if isNil sub then return h else meldNode h sub
      -- @rest@ is the heap without @n@; put @n@ back as a singleton.
      if isNil rest then return n else meldNode rest n
  where
    -- Detach all of @n@'s children, combined into one tree ('Nil' if none),
    -- leaving no links between that tree and @n@.
    orphanChildren = do
      sub <- combineSiblings n
      set child n Nil
      unless (isNil sub) $ set parent sub Nil
      return sub
-- | Detach @n@ from its parent: remove it from the parent's child list and
-- clear @n@'s parent and sibling links. @n@'s own children are untouched.
--
-- A node without a parent (a root) has nothing to be removed from, so only
-- its links are cleared; the parent's fields are never read through 'Nil'.
cutParent :: Ord k => Node k v s -> ST s ()
cutParent n = do
  p <- get parent n
  unless (isNil p) $ do
    x <- get child p
    if x == n
      then get sibling n >>= set child p -- n was the first child
      else unless (isNil x) $ deleteSibling n x -- search the rest of the list
  set parent n Nil
  set sibling n Nil
-- | Unlink @n@ from the sibling list that continues after @p@.
--
-- Walks the list starting at @p@; when @n@ is found as the successor of some
-- node, that predecessor's sibling link is redirected to @n@'s successor.
-- (Writing the link on @n@ itself would leave @n@ in the list.) If the end
-- of the list is reached without finding @n@, nothing is changed.
deleteSibling :: Node k v s -> Node k v s -> ST s ()
deleteSibling n p = get sibling p >>= step
  where
    step q
      | isNil q = return () -- end of list: n is not among these siblings
      | n == q = get sibling n >>= set sibling p -- bypass n
      | otherwise = deleteSibling n q