module Coda.Relative.Map
  ( Map
  , singleton
  , insert
  , lookup
  , toAscList
  , union
  , split
  , irfoldr
  ) where
import Coda.Internal.Map.BitQueue
import Coda.Relative.Class
import Coda.Relative.Delta.Type
import Control.Lens
import Data.Default
import Data.Hashable
import Data.Function (on)
import Data.Monoid
import GHC.Exts
import Prelude hiding (lookup)
type Size = Int
data Map k a
  = Bin  !Size  !Delta !k !a !(Map k a) !(Map k a)
  | Tip
  deriving Show
type role Map nominal nominal
instance Relative (Map k a) where
  rel _ Tip = Tip
  rel 0 m   = m 
  rel d (Bin s d' k a l r) = Bin s (d+d') k a l r
size :: Map k a -> Int
size (Bin s _ _ _ _ _) = s
size Tip = 0
irfoldr :: (Delta -> k -> a -> r -> r) -> r -> Delta -> Map k a -> r
irfoldr _ z !_ Tip = z
irfoldr f z d (Bin _ d' k a l r) | !d'' <- d <> d' = irfoldr f (f d'' k a (irfoldr f z d'' r)) d'' l
toAscList :: (StrictRelativeOrder k, Relative a) => Map k a -> [(k,a)]
toAscList = irfoldr (\d k x xs -> (rel d k,rel d x):xs) [] 0
instance (StrictRelativeOrder k, Relative a, Eq a) => Eq (Map k a) where (==) = on (==) toAscList
instance (StrictRelativeOrder k, Relative a, Ord a) => Ord (Map k a) where compare = on compare toAscList
instance (StrictRelativeOrder k, RelativeOrder a) => RelativeOrder (Map k a)
instance (StrictRelativeOrder k, StrictRelativeOrder a) => StrictRelativeOrder (Map k a)
instance (StrictRelativeOrder k, Hashable k, Relative a, Hashable a) => Hashable (Map k a) where
  hashWithSalt d = hashWithSalt d . toAscList
type instance IxValue (Map k a) = a
type instance Index (Map k a) = k
instance (StrictRelativeOrder k, Relative a) => Ixed (Map k a)
instance (StrictRelativeOrder k, Relative a) => At (Map k a) where
  at !k f m = case lookupTrace k m of
    TraceResult mv d q -> f mv <&> \case
      Nothing -> case mv of
        Nothing -> m
        Just old -> deleteAlong old q m
      Just !new | !nd <- negate d -> case mv of
        Nothing -> insertAlong q (rel nd k) (rel nd new) m
        Just _  -> replaceAlong q (rel nd new) m
  
singleton :: k -> a -> Map k a
singleton k a = Bin 1 mempty k a Tip Tip
instance Default (Map k a) where
  def = Tip
instance (StrictRelativeOrder k, Relative a) => Monoid (Map k a) where
  mempty = Tip
  mappend = union
instance (StrictRelativeOrder k, Relative a) => RelativeMonoid (Map k a)
lookup :: (StrictRelativeOrder k, Relative a) => k -> Map k a -> Maybe a
lookup = go 0 where
  go !_ !_ Tip = Nothing
  go d k (Bin _ d' kx x l r) | !d'' <- d+d' = case compare k (rel d'' kx) of
    LT -> go d'' k l
    GT -> go d'' k r
    EQ -> Just (rel d'' x)
union :: (StrictRelativeOrder k, Relative a) => Map k a -> Map k a -> Map k a
union t1 Tip  = t1
union t1 (Bin _ d k x Tip Tip) = insertRD 0 (rel d k) (rel d x) t1
union (Bin _ d k x Tip Tip) t2 = insertD 0 (rel d k) (rel d x) t2
union Tip t2 = t2
union t1@(Bin _ d1 k1 x1 l1 r1) t2 = case split d1 k1 t2 of
  (l2, r2)
    | ptrEq l1l2 l1 && ptrEq r1r2 r1 -> t1
    | otherwise -> link d1 k1 x1 l1l2 r1r2
    where
      !l1l2 = union l1 l2
      !r1r2 = union r1 r2
insert :: StrictRelativeOrder k => k -> a -> Map k a -> Map k a
insert = insertD 0 
insertD :: (Ord k, Relative k) => Delta -> k -> a -> Map k a -> Map k a
insertD !d !kx !x Tip = Bin 1 (negate d) kx x Tip Tip
insertD d kx x t@(Bin sz dy ky y l r) | !d'' <- d+dy = case compare kx (rel d'' ky) of
  LT | ptrEq l' l -> t
     | otherwise -> balanceL dy ky y l' r
     where !l' = insertD d'' kx x l
  GT | ptrEq r' r -> t
     | otherwise -> balanceR dy ky y l r'
     where !r' = insertD d'' kx x r
  EQ -> Bin sz (negate d) kx x (rel d'' l) (rel d'' r)
insertRD :: (Ord k, Relative k) => Delta -> k -> a -> Map k a -> Map k a
insertRD !d !kx !x Tip = Bin 1 (d) kx x Tip Tip
insertRD d kx x t@(Bin _ dy ky y l r) | !d'' <- d+dy = case compare kx (rel d'' ky) of
  LT | ptrEq l' l -> t
     | otherwise -> balanceL dy ky y l' r
     where !l' = insertRD d'' kx x l
  GT | ptrEq r' r -> t
     | otherwise -> balanceR dy ky y l r'
     where !r' = insertRD d'' kx x r
  EQ -> t
split :: StrictRelativeOrder k => Delta -> k -> Map k a -> (Map k a, Map k a)
split !d0 !k0 t0 = toPair $ go d0 k0 t0 where
  go :: (Relative k, Ord k) => Delta -> k -> Map k a -> StrictPair (Map k a) (Map k a)
  go d k t = case t of
    Tip -> Tip :*: Tip
    Bin _ dx kx x l r -> case compare (rel (ddx) k) kx of
      LT -> let lt :*: gt = go (ddx) k l in lt :*: link dx kx x gt r
      GT -> let lt :*: gt = go (ddx) k r in link dx kx x l lt :*: gt
      EQ -> rel dx l :*: rel dx r
ptrEq :: a -> a -> Bool
ptrEq x y = isTrue# (reallyUnsafePtrEquality# x y)
toPair :: StrictPair a b -> (a, b)
toPair (a :*: b) = (a, b)
link :: Delta -> k -> a -> Map k a -> Map k a -> Map k a
link d kx x Tip r  = insertMin d kx x r
link d kx x l Tip  = insertMax d kx x l
link d kx x l@(Bin sizeL dy ky y ly ry) r@(Bin sizeR dz kz z lz rz)
  | delta*sizeL < sizeR  = balanceL (d+dz) kz z (link (negate dz) kx x l (rel dz lz)) rz
  | delta*sizeR < sizeL  = balanceR (d+dy) ky y ly (link (negate dy) kx x (rel dy ry) r)
  | otherwise            = bin d kx x l r
bin :: Delta -> k -> a -> Map k a -> Map k a -> Map k a
bin d kx x l r = Bin (size l + size r + 1) d kx x l r
insertMax,insertMin :: Delta -> k -> a -> Map k a -> Map k a
insertMax d kx x t = case t of
  Tip -> Bin 1 (negate d) kx x Tip Tip
  Bin _ dy ky y l r -> balanceR dy ky y l (insertMax (d+dy) kx x r)
insertMin d kx x t = case t of
  Tip -> Bin 1 (negate d) kx x Tip Tip
  Bin _ dy ky y l r -> balanceL dy ky y (insertMin (d+dy) kx x l) r
data StrictPair a b = !a :*: !b
data TraceResult a = TraceResult !(Maybe a)  !Delta  !BitQueue
lookupTrace :: (Ord k, Relative k) => k -> Map k a -> TraceResult a
lookupTrace = go mempty emptyQB where
  go :: (Ord k, Relative k) => Delta -> BitQueueB -> k -> Map k a -> TraceResult a
  go !d !q !_ Tip = TraceResult Nothing d (buildQ q)
  go d q k (Bin _ d' kx x l r) | !d'' <- d+d' = case compare k (rel d'' kx) of
    LT -> go d'' (snocQB q False) k l
    GT -> go d'' (snocQB q True) k r
    EQ -> TraceResult (Just x) d'' (buildQ q)
insertAlong :: BitQueue -> k -> a -> Map k a -> Map k a
insertAlong !_ kx x Tip = singleton kx x
insertAlong q kx x (Bin _ d' ky y l r) = case unconsQ q of
  Just (False, tl) -> balanceL d' ky y (insertAlong tl kx x l) r
  Just (True,tl) -> balanceR d' ky y l (insertAlong tl kx x r)
  Nothing -> error "Coda.Relative.Map.insertAlong: failure"
deleteAlong :: any -> BitQueue -> Map k a -> Map k a
deleteAlong old !q0 !m = go (bogus old) q0 m where
  go :: Proxy# () -> BitQueue -> Map k a -> Map k a
  go !_ !_ Tip = Tip
  go foom q (Bin _ d' ky y l r) = case unconsQ q of
    Just (False, tl) -> balanceR d' ky y (go foom tl l) r
    Just (True, tl) -> balanceL d' ky y l (go foom tl r)
    Nothing -> glue l r
bogus :: a -> Proxy# ()
bogus _ = proxy#
replaceAlong :: BitQueue -> a -> Map k a -> Map k a
replaceAlong !_ _ Tip = error "Coda.Relative.Map.replaceAlong: failure"
replaceAlong q x (Bin sz d' ky y l r) = case unconsQ q of
  Just (False, tl) -> Bin sz d' ky y (replaceAlong tl x l) r
  Just (True,tl) -> Bin sz d' ky y l (replaceAlong tl x r)
  Nothing -> Bin sz d' ky x l r
balanceL :: Delta -> k -> a -> Map k a -> Map k a -> Map k a
balanceL d k x l r = case r of
  Tip -> case l of
    Tip -> Bin 1 d k x Tip Tip
    Bin _ _ _ _ Tip Tip -> Bin 2 d k x l Tip
    Bin _ ld lk lx Tip (Bin _ lrd lrk lrx _ _) -> Bin 3 (d+ld+lrd) lrk lrx (Bin 1 (lrd) lk lx Tip Tip) (Bin 1 (ldlrd) k x Tip Tip)
    Bin _ ld lk lx ll@Bin{} Tip -> Bin 3 (d+ld) lk lx ll (Bin 1 (ld) k x Tip Tip)
    Bin ls ld lk lx ll@(Bin lls _ _ _ _ _) lr@(Bin lrs lrd lrk lrx lrl lrr)
      | lrs < ratio*lls -> Bin (1+ls) (d+ld) lk lx ll (Bin (1+lrs) (ld) k x (rel ld lr) Tip)
      | otherwise -> Bin (1+ls) (d+ld+lrd) lrk lrx 
        (Bin (1+lls+size lrl) (lrd) lk lx ll (rel lrd lrl))
        (Bin (1+size lrr) (lrdld) k x (rel (ld+lrd) lrr) Tip)
  Bin rs _ _ _ _ _ -> case l of
    Tip -> Bin (1+rs) d k x Tip r
    Bin ls ld lk lx ll lr
      | ls > delta*rs  -> case (ll, lr) of
        (Bin lls _ _ _ _ _, Bin lrs lrd lrk lrx lrl lrr)
          | lrs < ratio*lls -> Bin (1+ls+rs) (d+ld) lk lx ll (Bin (1+rs+lrs) (ld) k x (rel ld lr) r)
          | otherwise -> Bin (1+ls+rs) (d+ld+lrd) lrk lrx 
            (Bin (1+lls+size lrl) (lrd) lk lx ll (rel lrd lrl))
            (Bin (1+rs+size lrr) (lrdld) k x (rel (lrd+ld) lrr) r)
        (_, _) -> error "Coda.Relative.Map.balanceL: failure"
      | otherwise -> Bin (1+ls+rs) d k x l r
balanceR :: Delta -> k -> a -> Map k a -> Map k a -> Map k a
balanceR d k x l r = case l of
  Tip -> case r of
    Tip -> Bin 1 d k x Tip Tip
    Bin _ _ _ _ Tip Tip -> Bin 2 d k x Tip r
    Bin _ rd rk rx Tip rr@Bin{} -> Bin 3 (d+rd) rk rx (Bin 1 (rd) k x Tip Tip) rr
    Bin _ rd rk rx (Bin _ rld rlk rlx _ _) Tip -> Bin 3 (d+rd+rld) rlk rlx (Bin 1 (rldrd) k x Tip Tip) (Bin 1 (rld) rk rx Tip Tip)
    Bin rs rd rk rx rl@(Bin rls rld rlk rlx rll rlr) rr@(Bin rrs _ _ _ _ _)
      | rls < ratio*rrs -> Bin (1+rs) (d+rd) rk rx (Bin (1+rls) (rd) k x Tip (rel rd rl)) rr
      | otherwise -> Bin (1+rs) (d+rd+rld) rlk rlx
        (Bin (1+size rll) (rldrd) k x Tip (rel (rd+rld) rll))
        (Bin (1+rrs+size rlr) (rld) rk rx (rel rld rlr) rr)
  Bin ls _ _ _ _ _ -> case r of
    Tip -> Bin (1+ls) d k x l Tip
    Bin rs rd rk rx rl rr
      | rs > delta*ls  -> case (rl, rr) of
        (Bin rls rld rlk rlx rll rlr, Bin rrs _ _ _ _ _)
          | rls < ratio*rrs -> Bin (1+ls+rs) (d+rd) rk rx (Bin (1+ls+rls) (rd) k x l (rel rd rl)) rr
          | otherwise -> Bin (1+ls+rs) (d+rd+rld) rlk rlx
            (Bin (1+ls+size rll) (rdrld) k x l (rel (rd+rld) rll))
            (Bin (1+rrs+size rlr) (rld) rk rx (rel rld rlr) rr)
        (_, _) -> error "Coda.Relative.Map.balanceR: failure"
      | otherwise -> Bin (1+ls+rs) d k x l r
data MinView k a = MinView !Delta !k !a !(Map k a)
data MaxView k a = MaxView !Delta !k !a !(Map k a)
glue :: Map k a -> Map k a -> Map k a
glue Tip r = r
glue l Tip = l
glue (Bin sl dl kl xl ll lr) (Bin sr dr kr xr rl rr)
  | sl > sr   = let !(MaxView dm km m l') = maxViewSure dl kl xl ll lr in balanceR dm km m l' (Bin sr (drdm) kr xr rl rr)
  | otherwise = let !(MinView dm km m r') = minViewSure dr kr xr rl rr in balanceL dm km m (Bin sl (dldm) kl xl ll lr) r'
minViewSure :: Delta -> k -> a -> Map k a -> Map k a -> MinView k a
minViewSure = go where
  go !d k x Tip r = MinView d k x r
  go d k x (Bin _ dl kl xl ll lr) r = case go (d+dl) kl xl ll lr of
    MinView dm km xm l' -> MinView dm km xm (balanceR (dm) k x (rel dm l') r)
maxViewSure :: Delta -> k -> a -> Map k a -> Map k a -> MaxView k a
maxViewSure = go where
  go !d k x l Tip = MaxView d k x l
  go d k x l (Bin _ dr kr xr rl rr) = case go (d+dr) kr xr rl rr of
    MaxView dm km xm r' -> MaxView dm km xm (balanceL (dm) k x l (rel dm r'))
delta,ratio :: Int
delta = 3
ratio = 2