Skip to content

Commit

Permalink
feat: handle simple local/parameters/global kind of variables
Browse files Browse the repository at this point in the history
  • Loading branch information
Jabolol committed Dec 6, 2024
1 parent c4a1dd3 commit d8d30a9
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 55 deletions.
12 changes: 8 additions & 4 deletions app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@ main :: IO ()
main = do
let ast =
T.AST
[ T.Define "example" (T.Lit (T.LInt 21)),
[ T.Define "value" (T.Lit (T.LInt 21)),
T.Define
"$$generated"
( T.If
(T.Lit (T.LInt 1))
( T.Call
( T.Lambda ["x"] (T.Op T.Mult (T.Var "x") (T.Lit (T.LInt 2)))
( T.Lambda
["x"]
[ T.Define "example" (T.Lit (T.LInt 2)),
T.Op T.Mult (T.Var "x") (T.Var "example")
]
)
[T.Var "example"]
[T.Var "value"]
)
( T.Call
( T.Lambda ["y"] (T.Op T.Sub (T.Var "y") (T.Lit (T.LInt 1)))
( T.Lambda ["y"] [T.Op T.Sub (T.Var "y") (T.Lit (T.LInt 1))]
)
[T.Lit (T.LInt 2)]
)
Expand Down
1 change: 1 addition & 0 deletions glados.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ library
llvm-hs-pretty >=0.9.0 && <0.10,
llvm-hs-pure >=9.0.0 && <9.1,
megaparsec >=9.7.0,
mtl >=2.2.2 && <2.3,

hs-source-dirs: lib
default-language: Haskell2010
Expand Down
2 changes: 1 addition & 1 deletion lib/Ast/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ data Expr
| Var String
| Define String Expr
| Call Expr [Expr]
| Lambda [String] Expr
| Lambda [String] [Expr]
| If Expr Expr Expr
| Op Operation Expr Expr
deriving (Show, Eq)
Expand Down
157 changes: 107 additions & 50 deletions lib/Codegen/Codegen.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo #-}

Expand All @@ -8,7 +10,8 @@ module Codegen.Codegen where
import qualified Ast.Types as AT
import qualified Codegen.Utils as U
import qualified Control.Monad as CM
import qualified Control.Monad.Fix as Fix
import qualified Control.Monad.Fix as F
import qualified Control.Monad.State as S
import qualified LLVM.AST as AST
import qualified LLVM.AST.Constant as C
import qualified LLVM.AST.IntegerPredicate as IP
Expand All @@ -17,8 +20,24 @@ import qualified LLVM.IRBuilder.Instruction as I
import qualified LLVM.IRBuilder.Module as M
import qualified LLVM.IRBuilder.Monad as IRM

-- | Type alias for the code generation state.
type CodegenState = [(String, AST.Operand)]

-- | Type alias for the monad stack used for code generation.
type MonadCodegen m = (IRM.MonadIRBuilder m, M.MonadModuleBuilder m, Fix.MonadFix m)
type MonadCodegen m =
( IRM.MonadIRBuilder m,
M.MonadModuleBuilder m,
F.MonadFix m,
S.MonadState CodegenState m
)

-- | Helper functions to manage state
getVarBinding :: (MonadCodegen m) => String -> m (Maybe AST.Operand)
getVarBinding name = S.gets (lookup name)

-- | Adds a variable binding to the state.
addVarBinding :: (MonadCodegen m) => String -> AST.Operand -> m ()
addVarBinding name op = S.modify ((name, op) :)

-- | Converts a parameter name to a pair of type and parameter name.
-- The `toParamType` function takes a string and returns a pair of type and parameter name.
Expand All @@ -29,15 +48,17 @@ toParamType param = (T.i32, M.ParameterName $ U.stringToByteString param)
-- | Generates LLVM code for a given abstract syntax tree (AST).
-- The `codegen` function takes an AST and returns the corresponding LLVM module.
codegen :: AT.AST -> AST.Module
codegen (AT.AST exprs) = M.buildModule "$$generated" $ do
IRM.runIRBuilderT IRM.emptyIRBuilder $ mapM_ generateTopLevel exprs
codegen (AT.AST exprs) =
M.buildModule "$$generated" $
IRM.runIRBuilderT IRM.emptyIRBuilder $
S.evalStateT (mapM_ generateTopLevel exprs) []

-- | Generates LLVM code for a top-level expression.
-- The `generateTopLevel` function takes an expression and generates the corresponding LLVM code.
generateTopLevel :: (MonadCodegen m) => AT.Expr -> m ()
generateTopLevel expr = case expr of
AT.Define name (AT.Lit var) -> CM.void $ buildGlobaVariable (AST.mkName name) var
AT.Define name body -> CM.void $ buildLambda (AST.mkName name) [] body
AT.Define name body -> CM.void $ buildLambda (AST.mkName name) [] [body]
_ -> error ("Unsupported top-level expression: " ++ show expr)

-- | Maps binary operators to LLVM instructions.
Expand All @@ -59,16 +80,16 @@ binaryOps =
-- | Generates LLVM code for an if expression.
generateIf :: (MonadCodegen m) => AT.Expr -> AT.Expr -> AT.Expr -> m AST.Operand
generateIf cond then_ else_ = mdo
condValue <- generateExpr [] cond
condValue <- generateExpr cond
test <- I.icmp IP.NE condValue (AST.ConstantOperand $ C.Int 1 0)
I.condBr test thenBlock elseBlock

thenBlock <- IRM.block `IRM.named` "then"
thenValue <- generateExpr [] then_
thenValue <- generateExpr then_
I.br mergeBB

elseBlock <- IRM.block `IRM.named` "else"
elseValue <- generateExpr [] else_
elseValue <- generateExpr else_
I.br mergeBB

mergeBB <- IRM.block `IRM.named` "merge"
Expand All @@ -77,8 +98,8 @@ generateIf cond then_ else_ = mdo
-- | Generates LLVM code for a binary operation.
generateOp :: (MonadCodegen m) => AT.Operation -> AT.Expr -> AT.Expr -> m AST.Operand
generateOp op e1 e2 = do
v1 <- generateExpr [] e1
v2 <- generateExpr [] e2
v1 <- generateExpr e1
v2 <- generateExpr e2
case lookup op binaryOps of
Just instruction -> instruction v1 v2
Nothing -> error $ "Unsupported operator: " ++ show op
Expand All @@ -96,49 +117,85 @@ buildGlobaVariable name value = do
-- | Generates LLVM code for a lambda expression.
-- The `buildLambda` function takes a name, a list of parameter names, and a body expression,
-- and returns an LLVM operand representing the lambda function.
buildLambda :: (MonadCodegen m) => AST.Name -> [String] -> AT.Expr -> m AST.Operand
buildLambda :: (MonadCodegen m) => AST.Name -> [String] -> [AT.Expr] -> m AST.Operand
buildLambda name params body = do
M.function
name
[toParamType param | param <- params]
T.i32
$ \paramOps -> do
let paramMap = zip params paramOps
result <- generateExpr paramMap body
I.ret result
M.function name [toParamType param | param <- params] T.i32 $ \paramOps -> do
oldState <- S.get
CM.forM_ (zip params paramOps) $ uncurry addVarBinding
results <- mapM generateExpr body
S.put oldState
CM.when (null results) I.retVoid
I.ret $ last results

-- | Generates an LLVM operand for an expression.
-- The `generateExpr` function recursively processes different expression types
-- and generates the corresponding LLVM code.
generateExpr :: (MonadCodegen m) => [(String, AST.Operand)] -> AT.Expr -> m AST.Operand
generateExpr paramMap expr = case expr of
AT.Lit (AT.LInt n) ->
pure $ AST.ConstantOperand $ C.Int 32 (fromIntegral n)
AT.Lit (AT.LBool b) ->
pure $ AST.ConstantOperand $ C.Int 1 (if b then 1 else 0)
AT.Op op e1 e2 -> do
v1 <- generateExpr paramMap e1
v2 <- generateExpr paramMap e2
case lookup op binaryOps of
Just instruction -> instruction v1 v2
Nothing -> error $ "Unsupported operator: " ++ show op
generateExpr :: (MonadCodegen m) => AT.Expr -> m AST.Operand
generateExpr expr = case expr of
AT.Lit lit -> generateLiteral lit
AT.Op op e1 e2 -> generateOp op e1 e2
AT.If cond then_ else_ -> generateIf cond then_ else_
AT.Call func args -> do
func' <- case func of
AT.Lambda params body -> do
uniqueName <- IRM.freshName "lambda"
buildLambda uniqueName params body
AT.Var name -> error $ "Calling a variable as a function is not supported yet: " ++ name
_ -> generateExpr paramMap func
args' <- mapM (generateExpr paramMap) args
I.call func' [(arg, []) | arg <- args']
AT.Var name ->
case lookup name paramMap of
Just op -> pure op
Nothing -> do
let globalVarPtr = AST.ConstantOperand (C.GlobalReference (T.ptr T.i32) (AST.mkName name))
I.load globalVarPtr 0
AT.Lambda params body -> do
uniqueName <- IRM.freshName "lambda"
buildLambda uniqueName params body
_ -> error ("Unsupported expression: " ++ show expr)
AT.Call func args -> generateCall func args
AT.Var name -> generateVar name
AT.Define name body -> generateDefine name body
AT.Lambda params body -> generateLambda params body

-- | Generates an LLVM operand for a literal.
-- The `generateLiteral` function takes a literal and returns the corresponding LLVM operand.
generateLiteral :: (MonadCodegen m) => AT.Literal -> m AST.Operand
generateLiteral =
pure . AST.ConstantOperand . \case
AT.LInt n -> C.Int 32 (fromIntegral n)
AT.LBool b -> C.Int 1 (if b then 1 else 0)
_ -> error "Unsupported literal type"

-- | Generates an LLVM operand for a function call.
-- The `generateCall` function takes a function expression and a list of argument expressions,
-- and returns the corresponding LLVM operand.
generateCall :: (MonadCodegen m) => AT.Expr -> [AT.Expr] -> m AST.Operand
generateCall func args = do
func' <- case func of
AT.Lambda params body -> do
uniqueName <- IRM.freshName "lambda"
buildLambda uniqueName params body
AT.Var name ->
error $ "Calling a variable as a function is not supported yet: " ++ name
_ -> generateExpr func
args' <- mapM generateExpr args
I.call func' [(arg, []) | arg <- args']

-- | Generates an LLVM operand for a variable.
-- The `generateVar` function takes a variable name and returns the corresponding LLVM operand.
generateVar :: (MonadCodegen m) => String -> m AST.Operand
generateVar name = do
maybeOp <- getVarBinding name
case maybeOp of
Just op -> pure op
Nothing -> do
let globalVarPtr =
AST.ConstantOperand $
C.GlobalReference (T.ptr T.i32) (AST.mkName name)
I.load globalVarPtr 0

-- | Generates an LLVM operand for a definition.
-- The `generateDefine` function takes a variable name and an expression,
-- and returns the corresponding LLVM operand.
generateDefine :: (MonadCodegen m) => String -> AT.Expr -> m AST.Operand
generateDefine name = \case
AT.Lit var -> do
let constant = case var of
AT.LInt i -> C.Int 32 (fromIntegral i)
AT.LBool b -> C.Int 1 (if b then 1 else 0)
_ -> error ("Local variable cannot be created with value: " ++ show var)
let op = AST.ConstantOperand constant
addVarBinding name op
generateExpr (AT.Lit var)
expr -> error ("Unsupported expression in definition: " ++ show expr)

-- | Generates an LLVM operand for a lambda expression.
-- The `generateLambda` function takes a list of parameter names and a list of body expressions,
-- and returns the corresponding LLVM operand.
generateLambda :: (MonadCodegen m) => [String] -> [AT.Expr] -> m AST.Operand
generateLambda params body = do
uniqueName <- IRM.freshName "lambda"
buildLambda uniqueName params body

0 comments on commit d8d30a9

Please # to comment.