module Data.Struct.TH (makeStruct) where
import Control.Monad (when, zipWithM)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Either (partitionEithers)
import Data.Primitive
import Data.Struct
import Data.Struct.Internal (Dict(Dict), initializeUnboxedField, st)
import Data.List (groupBy, nub)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax (VarStrictType)
#ifdef HLINT
#endif
-- | Everything 'computeRep' extracts from one data declaration before any
-- code is generated for it.
data StructRep = StructRep
  { srState       :: Name         -- ^ Name of the last type variable (the state type)
  , srName        :: Name         -- ^ Name of the declared type
  , srTyVars      :: [TyVarBndr]  -- ^ All type variable binders, state variable included
  , srDerived     :: [Name]       -- ^ Names listed in the deriving clause
  , srCxt         :: Cxt          -- ^ Context of the declaration
  , srConstructor :: Name         -- ^ Name of the single record constructor
  , srMembers     :: [Member]     -- ^ Validated constructor fields, in declaration order
  }
  deriving (Show)
-- | One validated record field of the struct being generated.
data Member = Member
  { _memberRep  :: Representation  -- ^ How the field is stored
  , memberName  :: Name            -- ^ Field name as written in the declaration
  , _memberType :: Type            -- ^ Field type; for a slot, the type with its
                                   --   trailing state argument removed
  }
  deriving (Show)
-- | Storage strategy for a field, selected from the field's strictness
-- annotation by 'validateMember'.
data Representation
  = BoxedField    -- ^ chosen for non-strict fields
  | UnboxedField  -- ^ chosen for unpacked fields
  | Slot          -- ^ chosen for strict fields
  deriving (Show)
-- | Process a quoted group of declarations.
--
-- Every data declaration is replaced by the code produced by 'generateCode';
-- every other declaration is returned unchanged, ahead of the generated code.
makeStruct :: DecsQ -> DecsQ
makeStruct dsq = do
  decls <- dsq
  classified <- mapM computeRep decls
  let (passthrough, reps) = partitionEithers classified
  -- The untouched declarations are handed to the generator so that it can
  -- detect user-supplied role annotations.
  generated <- mapM (generateCode passthrough) reps
  return (passthrough ++ concat generated)
-- | Name of the generated allocator: @alloc@ followed by the type's base name.
mkAllocName :: StructRep -> Name
mkAllocName = mkName . ("alloc" ++) . nameBase . srName
-- | Name of the generated initialising constructor: @new@ followed by the
-- type's base name.
mkInitName :: StructRep -> Name
mkInitName = mkName . ("new" ++) . nameBase . srName
computeRep :: Dec -> Q (Either Dec StructRep)
computeRep (DataD c n vs cs ds) =
do state <- validateStateType vs
(conname, confields) <- validateContructor cs
members <- traverse (validateMember state) confields
return $ Right StructRep
{ srState = state
, srName = n
, srTyVars = vs
, srConstructor = conname
, srMembers = members
, srDerived = ds
, srCxt = c
}
computeRep d = return (Left d)
-- | Require exactly one constructor, written in record syntax, and return its
-- name together with its fields. Anything else fails the splice.
validateContructor :: [Con] -> Q (Name, [VarStrictType])
validateContructor cons = case cons of
  [RecC name fields] -> return (name, fields)
  [_]                -> fail "Expected a record constructor"
  _                  -> fail ("Expected 1 constructor, got " ++ show (length cons))
-- | Pick the state type variable: the last binder of the declaration.
--
-- Fails when there are no binders, or when the last binder carries an
-- explicit kind other than @*@.
validateStateType :: [TyVarBndr] -> Q Name
validateStateType binders = case reverse binders of
  [] -> fail "state type expected but no type variables found"
  PlainTV n : _ -> return n
  KindedTV n k : _
    | k == starK -> return n
    | otherwise  -> fail "state type should have kind *"
-- | Validate one record field and decide its representation.
--
-- * non-strict field: 'BoxedField'; the state type must not occur in it.
-- * strict field: 'Slot'; the type must be an application whose final
--   argument is the state variable, and the state variable must not occur
--   anywhere else. The stored type is the application head.
-- * unpacked field: 'UnboxedField'; the state type must not occur in it.
validateMember :: Name -> VarStrictType -> Q Member
validateMember s (fieldname, strictness, fieldtype) =
  case strictness of
    NotStrict -> do
      forbidState fieldtype "state type may not occur in field `"
      return (Member BoxedField fieldname fieldtype)
    IsStrict -> do
      f <- unapplyType fieldtype s
      forbidState f "state type may only occur in final position in slot `"
      return (Member Slot fieldname f)
    Unpacked -> do
      forbidState fieldtype "state type may not occur in unpacked field `"
      return (Member UnboxedField fieldname fieldtype)
  where
    -- Fail with the given message prefix, completed by the quoted field
    -- name, when the state variable is mentioned in the type.
    forbidState :: Type -> String -> Q ()
    forbidState ty prefix =
      when (occurs s ty) (fail (prefix ++ nameBase fieldname ++ "`"))
-- | Strip a trailing application to the given type variable: for @f x@ where
-- @x@ is that variable, return @f@. Every other shape fails the splice.
unapplyType :: Type -> Name -> Q Type
unapplyType ty stateVar = case ty of
  AppT f (VarT x) | x == stateVar -> return f
  _ -> fail "Unable to match state type of slot"
-- | Run every generator for one struct and concatenate the resulting
-- declarations, in this fixed order: the newtype, the instances, the field
-- accessors, the @new…@ function, the @alloc…@ function and finally the role
-- annotation.
--
-- The first argument is the list of declarations passed through unchanged;
-- only 'generateRoles' looks at it.
generateCode :: [Dec] -> StructRep -> DecsQ
generateCode ds rep = concat <$> mapM ($ rep) generators
  where
    generators :: [StructRep -> DecsQ]
    generators =
      [ generateDataType
      , generateStructInstance
      , generateMembers
      , generateNew
      , generateAlloc
      , generateRoles ds
      ]
-- | Emit the newtype that replaces the user's data declaration. It keeps the
-- original context, name, type variables and deriving list, and its single
-- constructor wraps one non-strict @Object s@ field, @s@ being the state
-- type variable.
generateDataType :: StructRep -> DecsQ
generateDataType rep = do
  dec <- newtypeD
           (return (srCxt rep))
           (srName rep)
           (srTyVars rep)
           wrapper
           (srDerived rep)
  return [dec]
  where
    wrapper = normalC (srConstructor rep) [strictType notStrict payload]
    payload = [t| Object $(varT (srState rep)) |]
-- | Emit a role annotation for the struct type, unless the supplied
-- declarations already contain a role annotation for that type.
generateRoles :: [Dec] -> StructRep -> DecsQ
generateRoles ds rep =
  if any annotatesTarget ds
    then return []
    else fmap (: []) (roleAnnotD (srName rep) (computeRoles rep))
  where
    -- True for a role annotation naming the struct type.
    annotatesTarget d = case d of
      RoleAnnotD n _ -> n == srName rep
      _              -> False
-- | Default roles: one nominal role per type variable of the struct.
computeRoles :: StructRep -> [Role]
computeRoles rep = [NominalR | _ <- srTyVars rep]
-- | The struct type applied to all of its type variables except the last
-- one (the state variable).
repType1 :: StructRep -> TypeQ
repType1 rep = repTypeHelper (srName rep) leading
  where
    leading = init (srTyVars rep)
-- | The struct type applied to every one of its type variables.
repType :: StructRep -> TypeQ
repType rep = repTypeHelper tyName binders
  where
    tyName  = srName rep
    binders = srTyVars rep
-- | Apply a type constructor to the given binders, left to right.
repTypeHelper :: Name -> [TyVarBndr] -> TypeQ
repTypeHelper c vs = foldl applyTo (conT c) vs
  where
    applyTo acc v = appT acc (tyVarBndrT v)
-- | Turn a binder into the type that refers to it; a kinded binder keeps its
-- kind as a signature on the variable.
tyVarBndrT :: TyVarBndr -> TypeQ
tyVarBndrT bndr = case bndr of
  PlainTV n    -> varT n
  KindedTV n k -> sigT (varT n) k
-- | Emit the two instances every generated struct gets: a @Struct@ instance
-- for the type without its state argument, and an @Eq@ instance for the fully
-- applied type implemented by @eqStruct@.
generateStructInstance :: StructRep -> DecsQ
generateStructInstance rep = do
  structInst <- [d| instance Struct $(repType1 rep) where struct = Dict |]
  eqInst     <- [d| instance Eq $(repType rep) where (==) = eqStruct |]
  return (structInst ++ eqInst)
-- | Emit the @alloc…@ definition: a call to @alloc@ with the number of field
-- groups, where adjacent unboxed fields of equal type count as one group
-- (see 'isNeighbor').
generateAlloc :: StructRep -> DecsQ
generateAlloc rep = do
  mName <- newName "m"
  let m         = varT mName
      groups    = groupBy isNeighbor (srMembers rep)
      slotCount = length groups
      signature =
        forallT [PlainTV mName] (cxt [])
          [t| PrimMonad $m => $m ( $(repType1 rep) (PrimState $m) ) |]
  simpleDefinition rep (mkAllocName rep) signature [| alloc slotCount |]
-- | Emit the @new…@ function: it takes one argument per field, allocates the
-- struct through the generated @alloc…@ function, stores every argument and
-- returns the struct. The whole do-block is wrapped in @st@.
generateNew :: StructRep -> DecsQ
generateNew rep = do
  this <- newName "this"
  -- Pair each member with a fresh argument name, keeping the grouping that
  -- the allocator and the accessors use.
  groups <- mapM (mapM freshArg) (groupBy isNeighbor (srMembers rep))
  let name     = mkInitName rep
      thisE    = varE this
      allocS   = bindS (varP this) (varE (mkAllocName rep))
      assignSs = [ noBindS (assignN thisE ix grp) | (ix, grp) <- zip [0 ..] groups ]
      returnS  = noBindS [| return $thisE |]
      body     = doE (allocS : assignSs ++ [returnS])
      argPats  = [ varP arg | (arg, _) <- concat groups ]
  sequence
    [ sigD name (newStructType rep)
    , funD name [ clause argPats (normalB [| st $body |]) [] ]
    ]
  where
    freshArg m = do
      arg <- newName (nameBase (memberName m))
      return (arg, m)
-- | Build the expression that stores the constructor arguments of one field
-- group into the struct.
--
-- A lone boxed field uses @setField@ and a lone slot uses @set@, both through
-- the accessor named after the field. Every other group is handled as a run
-- of unboxed fields: @initializeUnboxedField@ is called with the group index,
-- the group size and the @sizeOf@ of the first argument, and each argument is
-- then written into the resulting byte array at its position in the group.
assignN :: ExpQ -> Int -> [(Name,Member)] -> ExpQ
assignN this i fields = case fields of
  [(arg, Member BoxedField n _)] ->
    [| setField $(varE n) $this $(varE arg) |]
  [(arg, Member Slot n _)] ->
    [| set $(varE n) $this $(varE arg) |]
  us -> do
    mba <- newName "mba"
    let n      = length us
        -- groups come from groupBy at the call site, which yields no
        -- empty groups
        arg0   = fst (head us)
        open   = bindS (varP mba)
                   [| initializeUnboxedField i n (sizeOf $(varE arg0)) $this |]
        writes = [ noBindS [| writeByteArray $(varE mba) j $(varE arg) |]
                 | (j, (arg, _)) <- zip [0 :: Int ..] us
                 ]
    doE (open : writes)
-- | Type signature of the generated @new…@ function: one argument per member
-- and a result in a @PrimMonad@ @m@. Slot arguments are the stored slot type
-- applied to @PrimState m@. A @Prim@ constraint is added for each distinct
-- bare type variable used as the type of an unboxed field.
newStructType :: StructRep -> TypeQ
newStructType rep = do
  mName <- newName "m"
  let m = varT mName
      s = [t| PrimState $m |]
      argType member = case member of
        Member Slot _ f -> [t| $(return f) $s |]
        Member _ _ ty   -> return ty
      result    = [t| $m ($(repType1 rep) $s) |]
      r         = foldr ((-->) . argType) result (srMembers rep)
      primVars  = nub [ v | Member UnboxedField _ (VarT v) <- srMembers rep ]
      primPreds = map primPred primVars
  forallRepT rep
    (forallT [PlainTV mName] (cxt primPreds) [t| PrimMonad $m => $r |])
-- | Emit the accessor definitions for every field. Members are grouped with
-- 'isNeighbor' and each group is generated with its zero-based group index.
generateMembers :: StructRep -> DecsQ
generateMembers rep = do
  decss <- zipWithM (generateMember1 rep) [0 ..] groups
  return (concat decss)
  where
    groups = groupBy isNeighbor (srMembers rep)
-- | Grouping predicate for members: two members belong to the same group
-- only when both are unboxed fields of the same type.
isNeighbor :: Member -> Member -> Bool
isNeighbor a b = case (a, b) of
  (Member UnboxedField _ t, Member UnboxedField _ u) -> t == u
  _                                                  -> False
-- | Emit the accessor definitions for one field group, given the group's
-- index.
--
-- A lone boxed field becomes @field n@ and a lone slot becomes @slot n@.
-- Every other group is treated as a run of unboxed fields, each becoming
-- @unboxedField n i@ with @i@ its position inside the group. When the type of
-- an unboxed field is a bare type variable, a @Prim@ constraint on it is
-- added to the signature.
generateMember1 :: StructRep -> Int -> [Member] -> DecsQ
generateMember1 rep n members = case members of
  [Member BoxedField fieldname fieldtype] ->
    simpleDefinition rep fieldname
      [t| Field $(repType1 rep) $(return fieldtype) |]
      [| field n |]
  [Member Slot slotname slottype] ->
    simpleDefinition rep slotname
      [t| Slot $(repType1 rep) $(return slottype) |]
      [| slot n |]
  us ->
    concat <$> sequence
      [ unboxedAccessor i fieldname fieldtype
      | (i, Member UnboxedField fieldname fieldtype) <- zip [0 :: Int ..] us
      ]
  where
    unboxedAccessor :: Int -> Name -> Type -> DecsQ
    unboxedAccessor i fieldname fieldtype =
      simpleDefinition rep fieldname
        (withPrim fieldtype [t| Field $(repType1 rep) $(return fieldtype) |])
        [| unboxedField n i |]

    withPrim :: Type -> TypeQ -> TypeQ
    withPrim (VarT v) = forallT [] (cxt [primPred v])
    withPrim _        = id
-- | Emit a top-level value in three parts: its type signature (quantified
-- over the struct's non-state type variables), its binding, and an @INLINE@
-- pragma.
simpleDefinition :: StructRep -> Name -> TypeQ -> ExpQ -> DecsQ
simpleDefinition rep name typ def = sequence [signature, binding, inlinePragma]
  where
    signature    = sigD name (forallRepT rep typ)
    binding      = simpleValD name def
    inlinePragma = pragInlD name Inline FunLike AllPhases
-- | A plain value binding @var = val@ with no guards and no where-clause.
simpleValD :: Name -> ExpQ -> DecQ
simpleValD var val = valD lhs rhs []
  where
    lhs = varP var
    rhs = normalB val
-- | Quantify a type over every type variable of the struct except the last
-- one (the state variable), with an empty context.
forallRepT :: StructRep -> TypeQ -> TypeQ
forallRepT rep body = forallT binders (cxt []) body
  where
    binders = init (srTyVars rep)
-- | Function-type constructor on quoted types: @f --> x@ builds @f -> x@.
(-->) :: TypeQ -> TypeQ -> TypeQ
f --> x = appT (appT arrowT f) x
-- | The constraint @Prim v@ for the type variable with the given name.
primPred :: Name -> PredQ
primPred v = conT ''Prim `appT` varT v
-- | Does the named type variable appear in the type?
--
-- The search descends through applications, the body of a @forall@ and the
-- type under a kind signature; every other form of type counts as not
-- mentioning the variable.
occurs :: Name -> Type -> Bool
occurs n ty = case ty of
  VarT m        -> n == m
  AppT f x      -> occurs n f || occurs n x
  ForallT _ _ t -> occurs n t
  SigT t _      -> occurs n t
  _             -> False