diff --git a/app/Main.hs b/app/Main.hs index 92f6875..5f4016c 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -2,7 +2,8 @@ module Main where -import qualified Ast.Parser as P +-- import qualified Ast.Parser as P +import Ast.Types import qualified Codegen.Codegen as C import qualified Control.Monad as M import qualified Control.Monad.IO.Class as IO @@ -59,16 +60,105 @@ optionsInfo = <> O.header "Scheme-to-LLVM Compiler" ) +sampleProgram :: Program +sampleProgram = + Program + { globals = + [ ("fibonacci", fibonacciFunction), + ("main", mainFunction) + ], + types = [], + sourceFile = "fibonacci.c" + } + where + fibonacciLoc = SrcLoc "fibonacci.c" 1 1 + nParamLoc = SrcLoc "fibonacci.c" 2 3 + ifLoc = SrcLoc "fibonacci.c" 3 3 + returnBaseCaseLoc = SrcLoc "fibonacci.c" 4 5 + recursiveCallLoc = SrcLoc "fibonacci.c" 5 5 + returnRecursiveLoc = SrcLoc "fibonacci.c" 6 5 + mainLoc = SrcLoc "fibonacci.c" 8 1 + resultLoc = SrcLoc "fibonacci.c" 9 3 + returnLoc = SrcLoc "fibonacci.c" 10 3 + + fibonacciFunction = + Function + { funcLoc = fibonacciLoc, + funcName = "fibonacci", + funcType = TFunction (TInt 32) [TInt 32] False, + funcParams = ["n"], + funcBody = + Block + [ If + { ifLoc = ifLoc, + ifCond = Op ifLoc Lte (Var nParamLoc "n" (TInt 32)) (Lit nParamLoc (LInt 1)), + ifThen = Return returnBaseCaseLoc (Just (Var nParamLoc "n" (TInt 32))), + ifElse = + Just $ + Return + returnRecursiveLoc + ( Just + ( Op + recursiveCallLoc + Add + ( Call + recursiveCallLoc + (Var recursiveCallLoc "fibonacci" (TFunction (TInt 32) [TInt 32] False)) + [Op recursiveCallLoc Sub (Var nParamLoc "n" (TInt 32)) (Lit recursiveCallLoc (LInt 1))] + ) + ( Call + recursiveCallLoc + (Var recursiveCallLoc "fibonacci" (TFunction (TInt 32) [TInt 32] False)) + [Op recursiveCallLoc Sub (Var nParamLoc "n" (TInt 32)) (Lit recursiveCallLoc (LInt 2))] + ) + ) + ) + } + ] + } + + mainFunction = + Function + { funcLoc = mainLoc, + funcName = "$$generated", + funcType = TFunction (TInt 32) [] False, + funcParams = [], + funcBody = + Block + [ Declaration + { declLoc = resultLoc, + declName = "n", + declType = TInt 32, + declInit = Just (Lit resultLoc (LInt 8)) + }, + Declaration + { declLoc = resultLoc, + declName = "result", + declType = TInt 32, + declInit = + Just + ( Call + resultLoc + (Var resultLoc "fibonacci" (TFunction (TInt 32) [TInt 32] False)) + [Var resultLoc "n" (TInt 32)] + ) + }, + Return returnLoc (Just (Var returnLoc "result" (TInt 32))) + ] + } + compile :: String -> String -> Bool -> E.ExceptT CompileError IO String -compile input source verbose = do - ast <- case P.parse input source of - Left err -> E.throwE $ ParseError err - Right res -> return res +compile _ _ verbose = do + -- ast <- case P.parse input source of + -- Left err -> E.throwE $ ParseError err + -- Right res -> return res + + let ast = sampleProgram IO.liftIO $ logMsg verbose (TL.unpack $ PS.pShow ast) case C.codegen ast of - Left err -> E.throwE $ CodegenError (show err) + Left err -> E.throwE $ CodegenError (TL.unpack $ PS.pShow err) Right lmod -> return $ TL.unpack $ LLVM.ppllvm lmod logMsg :: Bool -> String -> IO () @@ -87,7 +177,8 @@ handleError errType errMsg verbose = do main :: IO () main = do Options {input, out, verbose} <- O.execParser optionsInfo - source <- readInput input + -- source <- readInput input + let source = "" logMsg verbose "Starting compilation..." result <- E.runExceptT $ compile (DM.fromMaybe "stdin" input) source verbose diff --git a/glados.cabal b/glados.cabal index 75538f6..aed3191 100644 --- a/glados.cabal +++ b/glados.cabal @@ -34,7 +34,6 @@ library Ast.Types Codegen.Codegen Codegen.Utils - Misc build-depends: base ^>=4.17.2.1, @@ -74,6 +73,7 @@ test-suite glados-test Ast.Parser.UnaryOperationSpec Ast.ParserSpec Misc.MiscSpec + Codegen.CodegenSpec hs-source-dirs: test main-is: Spec.hs @@ -83,5 +83,6 @@ test-suite glados-test glados, hspec, hspec-discover, + llvm-hs-pure >=9.0.0 && <9.1, megaparsec >=9.7.0, - mtl >=2.2.2 && <2.3 + mtl >=2.2.2 && <2.3, diff --git a/lib/Ast/Types.hs b/lib/Ast/Types.hs index ad79751..4b4ac46 100644 --- a/lib/Ast/Types.hs +++ b/lib/Ast/Types.hs @@ -16,6 +16,7 @@ data Literal | LBool Bool | LArray [Literal] | LNull + | LStruct [(String, Literal)] deriving (Show, Eq, Ord) -- | Enhanced type system with size information and qualifiers diff --git a/lib/Codegen/Codegen.hs b/lib/Codegen/Codegen.hs index b674a26..cd6cab4 100644 --- a/lib/Codegen/Codegen.hs +++ b/lib/Codegen/Codegen.hs @@ -1,9 +1,11 @@ {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImpredicativeTypes #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE MonoLocalBinds #-} {-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE UndecidableInstances #-} module Codegen.Codegen where @@ -13,243 +15,524 @@ import qualified Control.Monad as CM import qualified Control.Monad.Except as E import qualified Control.Monad.Fix as F import qualified Control.Monad.State as S +import qualified Data.List as L import qualified LLVM.AST as AST import qualified LLVM.AST.Constant as C +import qualified LLVM.AST.Float as FF import qualified LLVM.AST.IntegerPredicate as IP import qualified LLVM.AST.Type as T +import qualified LLVM.AST.Typed as TD +import qualified LLVM.IRBuilder.Constant as IC import qualified LLVM.IRBuilder.Instruction as I import qualified LLVM.IRBuilder.Module as M import qualified LLVM.IRBuilder.Monad as IRM --- | Type alias for the code generation state. -type CodegenState = [(String, AST.Operand)] +-- | Type alias for the local code generation state. +type LocalState = [(String, AST.Operand)] --- | Type for codegen errors +-- | Type alias for the global code generation state. +type GlobalState = [(String, AST.Operand)] + +-- | Type alias for the loop code generation state. +type LoopState = Maybe (AST.Name, AST.Name) + +-- | Combined state for code generation. +data CodegenState = CodegenState + { localState :: LocalState, + globalState :: GlobalState, + loopState :: LoopState + } + deriving (Show) + +-- | Monad constraints for code generation. +type MonadCodegen m = + ( IRM.MonadIRBuilder m, + M.MonadModuleBuilder m, + F.MonadFix m, + S.MonadState CodegenState m, + E.MonadError CodegenError m + ) + +-- | Error types for code generation. data CodegenError + = CodegenError + { errorLoc :: AT.SrcLoc, + errorType :: CodegenErrorType + } + +data CodegenErrorType = UnsupportedTopLevel AT.Expr | UnsupportedOperator AT.Operation + | UnsupportedUnaryOperator AT.UnaryOperation | UnsupportedLiteral AT.Literal + | UnsupportedType AT.Type | UnsupportedGlobalVar AT.Literal | UnsupportedLocalVar AT.Literal | UnsupportedDefinition AT.Expr + | UnsupportedForDefinition AT.Expr + | UnsupportedWhileDefinition AT.Expr | VariableNotFound String | UnsupportedFunctionCall String + | ContinueOutsideLoop + | BreakOutsideLoop deriving (Show) --- | Type alias for the monad stack used for code generation. -type MonadCodegen m = - ( IRM.MonadIRBuilder m, - M.MonadModuleBuilder m, - F.MonadFix m, - S.MonadState CodegenState m, - E.MonadError CodegenError m - ) +instance Show CodegenError where + show (CodegenError loc err) = + AT.srcFile loc + ++ ":" + ++ show (AT.srcLine loc) + ++ ":" + ++ show (AT.srcCol loc) + ++ ": " + ++ showErrorType err + +showErrorType :: CodegenErrorType -> String +showErrorType err = case err of + UnsupportedTopLevel expr -> "Unsupported top-level expression: " ++ show expr + UnsupportedOperator op -> "Unsupported operator: " ++ show op + UnsupportedUnaryOperator op -> "Unsupported unary operator: " ++ show op + UnsupportedLiteral lit -> "Unsupported literal: " ++ show lit + UnsupportedType typ -> "Unsupported type: " ++ show typ + UnsupportedGlobalVar lit -> "Unsupported global variable: " ++ show lit + UnsupportedLocalVar lit -> "Unsupported local variable: " ++ show lit + UnsupportedDefinition expr -> "Unsupported definition: " ++ show expr + UnsupportedForDefinition expr -> "Invalid for loop: " ++ show expr + UnsupportedWhileDefinition expr -> "Invalid while loop: " ++ show expr + VariableNotFound name -> "Variable not found: " ++ name + UnsupportedFunctionCall name -> "Invalid function call: " ++ name + ContinueOutsideLoop -> "Continue statement outside loop" + BreakOutsideLoop -> "Break statement outside loop" --- | Helper functions to manage state -getVarBinding :: (MonadCodegen m) => String -> m (Maybe AST.Operand) -getVarBinding name = S.gets (lookup name) +-- | Variable binding typeclass. +class (Monad m) => VarBinding m where + getVar :: String -> m (Maybe AST.Operand) + addVar :: String -> AST.Operand -> m () + getGlobalVar :: String -> m (Maybe AST.Operand) + addGlobalVar :: String -> AST.Operand -> m () --- | Adds a variable binding to the state. -addVarBinding :: (MonadCodegen m) => String -> AST.Operand -> m () -addVarBinding name op = S.modify ((name, op) :) +instance (MonadCodegen m, Monad m) => VarBinding m where + getVar name = do + state <- S.get + return $ lookup name (localState state) `S.mplus` lookup name (globalState state) + addVar name operand = S.modify (\s -> s {localState = (name, operand) : localState s}) + getGlobalVar name = S.gets (lookup name . globalState) + addGlobalVar name operand = S.modify (\s -> s {globalState = (name, operand) : globalState s}) --- | Converts a parameter name to a pair of type and parameter name. --- The `toParamType` function takes a string and returns a pair of type and parameter name. --- The type is always `i64` (64-bit integer), and the parameter name is the input string. -toParamType :: String -> (T.Type, M.ParameterName) -toParamType param = (T.i64, M.ParameterName $ U.stringToByteString param) +-- | Type conversion to LLVM IR. +class ToLLVM a where + toLLVM :: a -> T.Type --- | Generates LLVM code for a given abstract syntax tree (AST). -codegen :: AT.AST -> Either CodegenError AST.Module -codegen (AT.AST exprs) = +instance ToLLVM AT.Type where + toLLVM expr = case expr of + AT.TInt width -> T.IntegerType (fromIntegral width) + AT.TFloat -> T.FloatingPointType T.FloatFP + AT.TDouble -> T.FloatingPointType T.DoubleFP + AT.TChar -> T.IntegerType 8 + AT.TBoolean -> T.IntegerType 1 + AT.TVoid -> T.void + AT.TPointer t -> T.ptr (toLLVM t) + AT.TArray t (Just n) -> T.ArrayType (fromIntegral n) (toLLVM t) + AT.TArray t Nothing -> T.ptr (toLLVM t) + AT.TFunction ret params var -> T.FunctionType (toLLVM ret) (map toLLVM params) var + AT.TStruct _ fields -> T.StructureType False (map (toLLVM . snd) fields) + AT.TUnion _ variants -> T.StructureType False (map (toLLVM . snd) variants) + AT.TTypedef _ t -> toLLVM t + AT.TMutable t -> toLLVM t + +-- | Generate LLVM code for a program. +codegen :: AT.Program -> Either CodegenError AST.Module +codegen program = E.runExcept $ - M.buildModuleT "$$generated" $ + M.buildModuleT (U.stringToByteString $ AT.sourceFile program) $ IRM.runIRBuilderT IRM.emptyIRBuilder $ - S.evalStateT (mapM_ generateTopLevel exprs) [] - --- | Generates LLVM code for a top-level expression. --- The `generateTopLevel` function takes an expression and generates the corresponding LLVM code. -generateTopLevel :: (MonadCodegen m) => AT.Expr -> m () -generateTopLevel = \case - AT.Define name (AT.Lit var) -> CM.void $ buildGlobaVariable (AST.mkName name) var - AT.Define name (AT.Lambda params body) -> - CM.void $ - buildFunction (AST.mkName name) params (AT.Lambda params body) - AT.Define name body -> CM.void $ buildLambda (AST.mkName name) [] body - expr -> E.throwError $ UnsupportedTopLevel expr - --- | Maps binary operators to LLVM instructions. -binaryOps :: [(AT.Operation, AST.Operand -> AST.Operand -> (IRM.MonadIRBuilder m) => m AST.Operand)] -binaryOps = - [ (AT.Add, I.add), - (AT.Sub, I.sub), - (AT.Mult, I.mul), - (AT.Div, I.sdiv), - (AT.Mod, I.srem), - (AT.Lt, I.icmp IP.SLT), - (AT.Gt, I.icmp IP.SGT), - (AT.Lte, I.icmp IP.SLE), - (AT.Gte, I.icmp IP.SGE), - (AT.Equal, I.icmp IP.EQ), - (AT.Ne, I.icmp IP.NE), - (AT.And, I.and), - (AT.Or, I.or) + S.evalStateT (mapM_ (generateGlobal . snd) (AT.globals program)) (CodegenState [] [] Nothing) + +-- | Generate LLVM code for global expressions. +generateGlobal :: (MonadCodegen m) => AT.Expr -> m () +generateGlobal expr = case expr of + AT.Function {} -> CM.void $ generateFunction expr + _ -> E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedTopLevel expr + +-- | Generate LLVM code for an expression. +class ExprGen a where + generateExpr :: (MonadCodegen m) => a -> m AST.Operand + +instance ExprGen AT.Expr where + generateExpr expr = case expr of + AT.Lit {} -> generateLiteral expr + AT.Var {} -> generateVar expr + AT.Function {} -> generateFunction expr + AT.Declaration {} -> generateDeclaration expr + AT.If {} -> generateIf expr + AT.Block {} -> generateBlock expr + AT.Return {} -> generateReturn expr + AT.Op {} -> generateBinaryOp expr + AT.UnaryOp {} -> generateUnaryOp expr + AT.Call {} -> generateFunctionCall expr + AT.ArrayAccess {} -> generateArrayAccess expr + AT.Cast {} -> generateCast expr + AT.For {} -> generateForLoop expr + AT.While {} -> generateWhileLoop expr + AT.Break {} -> generateBreak expr + AT.Continue {} -> generateContinue expr + AT.Assignment {} -> generateAssignment expr + _ -> E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for constants. +generateConstant :: (MonadCodegen m) => AT.Literal -> m C.Constant +generateConstant lit = case lit of + AT.LInt n -> return $ C.Int 32 (fromIntegral n) + AT.LChar c -> return $ C.Int 8 (fromIntegral $ fromEnum c) + AT.LBool b -> return $ C.Int 1 (if b then 1 else 0) + AT.LNull -> return $ C.Null T.i8 + AT.LFloat f -> pure $ C.Float (FF.Single (realToFrac f)) + AT.LArray elems -> do + constants <- mapM generateConstant elems + return $ C.Array (TD.typeOf $ head constants) constants + +-- | Generate LLVM code for literals. +generateLiteral :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateLiteral (AT.Lit _ lit) = do + constant <- generateConstant lit + pure $ AST.ConstantOperand constant +generateLiteral expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for binary operations. +generateBinaryOp :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateBinaryOp (AT.Op loc op e1 e2) = do + v1 <- generateExpr e1 + v2 <- generateExpr e2 + case findOperator op of + Just f -> f v1 v2 + Nothing -> E.throwError $ CodegenError loc $ UnsupportedOperator op + where + findOperator op' = L.find ((== op') . opMapping) binaryOperators >>= Just . opFunction +generateBinaryOp expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Binary operation data type. +data BinaryOp m = BinaryOp + { opMapping :: AT.Operation, + opFunction :: AST.Operand -> AST.Operand -> m AST.Operand + } + +-- | List of supported binary operators. +binaryOperators :: (MonadCodegen m) => [BinaryOp m] +binaryOperators = + [ BinaryOp AT.Add I.add, + BinaryOp AT.Sub I.sub, + BinaryOp AT.Mul I.mul, + BinaryOp AT.Div I.sdiv, + BinaryOp AT.Mod I.srem, + BinaryOp AT.BitAnd I.and, + BinaryOp AT.BitOr I.or, + BinaryOp AT.BitXor I.xor, + BinaryOp AT.BitShl I.shl, + BinaryOp AT.BitShr I.ashr, + BinaryOp AT.And I.and, + BinaryOp AT.Or I.or + ] + ++ map mkComparisonOp comparisonOps + where + mkComparisonOp (op, pre) = BinaryOp op (I.icmp pre) + comparisonOps = + [ (AT.Lt, IP.SLT), + (AT.Gt, IP.SGT), + (AT.Lte, IP.SLE), + (AT.Gte, IP.SGE), + (AT.Eq, IP.EQ), + (AT.Ne, IP.NE) + ] + +-- | Unary operation data type. +data UnaryOp m = UnaryOp + { unaryMapping :: AT.UnaryOperation, + unaryFunction :: AST.Operand -> m AST.Operand + } + +-- | List of supported unary operators. +unaryOperators :: (MonadCodegen m) => [UnaryOp m] +unaryOperators = + [ UnaryOp AT.Not (\operand -> I.xor operand (AST.ConstantOperand $ C.Int 1 1)), + UnaryOp AT.BitNot (\operand -> I.xor operand (AST.ConstantOperand $ C.Int 32 (-1))), + UnaryOp AT.Deref (`I.load` 0), + UnaryOp AT.AddrOf pure, + UnaryOp AT.PreInc (\operand -> I.add operand (AST.ConstantOperand $ C.Int 32 1)), + UnaryOp AT.PreDec (\operand -> I.sub operand (AST.ConstantOperand $ C.Int 32 1)), + UnaryOp AT.PostInc (postOp I.add), + UnaryOp AT.PostDec (postOp I.sub) ] + where + postOp op operand = do + result <- I.load operand 0 + I.store operand 0 =<< op result (AST.ConstantOperand $ C.Int 32 1) + pure result --- | Generates LLVM code for an if expression. -generateIf :: (MonadCodegen m) => AT.Expr -> AT.Expr -> AT.Expr -> m AST.Operand -generateIf cond then_ else_ = mdo +-- | Generate LLVM code for unary operations. +generateUnaryOp :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateUnaryOp (AT.UnaryOp _ op expr) = do + operand <- generateExpr expr + case findOperator op of + Just f -> f operand + Nothing -> E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedUnaryOperator op + where + findOperator op' = L.find ((== op') . unaryMapping) unaryOperators >>= Just . unaryFunction +generateUnaryOp expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for variable references. +generateVar :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateVar (AT.Var loc name _) = do + maybeVar <- getVar name + case maybeVar of + Just ptr -> case TD.typeOf ptr of + T.PointerType _ _ -> I.load ptr 0 + _ -> return ptr + Nothing -> E.throwError $ CodegenError loc $ VariableNotFound name +generateVar expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for blocks. +generateBlock :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateBlock (AT.Block exprs) = do + last <$> traverse generateExpr exprs +generateBlock expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for `if` expressions. +generateIf :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateIf (AT.If _ cond then_ else_) = mdo condValue <- generateExpr cond test <- I.icmp IP.NE condValue (AST.ConstantOperand $ C.Int 1 0) I.condBr test thenBlock elseBlock - thenBlock <- IRM.block `IRM.named` "then" + thenBlock <- IRM.block `IRM.named` U.stringToByteString "if.then" thenValue <- generateExpr then_ I.br mergeBB - elseBlock <- IRM.block `IRM.named` "else" - elseValue <- generateExpr else_ + elseBlock <- IRM.block `IRM.named` U.stringToByteString "if.else" + elseValue <- case else_ of + Just e -> generateExpr e + Nothing -> pure $ AST.ConstantOperand $ C.Undef T.void I.br mergeBB - mergeBB <- IRM.block `IRM.named` "merge" + mergeBB <- IRM.block `IRM.named` U.stringToByteString "if.merge" + I.phi [(thenValue, thenBlock), (elseValue, elseBlock)] +generateIf expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr --- | Generates LLVM code for a binary operation. -generateOp :: (MonadCodegen m) => AT.Operation -> AT.Expr -> AT.Expr -> m AST.Operand -generateOp op e1 e2 = do - v1 <- generateExpr e1 - v2 <- generateExpr e2 - case lookup op binaryOps of - Just instruction -> instruction v1 v2 - Nothing -> E.throwError $ UnsupportedOperator op - --- | Generates LLVM code for a literal variable. --- The `buildLiteralVariable` function takes the variable name and its value, -buildGlobaVariable :: (MonadCodegen m) => AST.Name -> AT.Literal -> m AST.Operand -buildGlobaVariable name = \case - AT.LInt i -> M.global name T.i64 (C.Int 64 $ fromIntegral i) - AT.LBool b -> M.global name T.i64 (C.Int 1 $ if b then 1 else 0) - value -> E.throwError $ UnsupportedGlobalVar value - --- | Generates LLVM code for a lambda expression. --- The `buildLambda` function takes a name, a list of parameter names, and a body expression, --- and returns an LLVM operand representing the lambda function. -buildLambda :: (MonadCodegen m) => AST.Name -> [String] -> AT.Expr -> m AST.Operand -buildLambda name params body = do - func' <- M.function name [toParamType param | param <- params] T.i64 $ \paramOps -> do - oldState <- S.get - CM.forM_ (zip params paramOps) $ uncurry addVarBinding - results <- generateExpr body - S.put oldState - I.ret results - addVarBinding (U.nameToString name) func' - return func' - --- | Generates LLVM code for a lambda expression. --- The `buildFunction` will build both a lambda and a named function. --- The function generated will just call the lambda function and return the result. -buildFunction :: (MonadCodegen m) => AST.Name -> [String] -> AT.Expr -> m AST.Operand -buildFunction name params body = do - M.function name [toParamType param | param <- params] T.i64 $ \paramOps -> do - oldState <- S.get - CM.forM_ (zip params paramOps) $ uncurry addVarBinding - results <- generateExpr body - S.put oldState - call' <- I.call results [(arg, []) | arg <- paramOps] - I.ret call' - --- | Generates an LLVM operand for an expression. --- The `generateExpr` function recursively processes different expression types --- and generates the corresponding LLVM code. -generateExpr :: (MonadCodegen m) => AT.Expr -> m AST.Operand -generateExpr = \case - AT.Lit lit -> generateLiteral lit - AT.Op op e1 e2 -> generateOp op e1 e2 - AT.If cond then_ else_ -> generateIf cond then_ else_ - AT.Call func args -> generateCall func args - AT.Var name -> generateVar name - AT.Define name body -> generateDefine name body - AT.Lambda params body -> generateLambda params body - AT.Seq exprs -> generateSeq exprs - --- | Generates an LLVM operand for a literal. --- The `generateLiteral` function takes a literal and returns the corresponding LLVM operand. -generateLiteral :: (MonadCodegen m) => AT.Literal -> m AST.Operand -generateLiteral = \case - AT.LInt n -> pure $ AST.ConstantOperand $ C.Int 64 (fromIntegral n) - AT.LBool b -> pure $ AST.ConstantOperand $ C.Int 1 (if b then 1 else 0) - lit -> E.throwError $ UnsupportedLiteral lit - --- | Generates an LLVM operand for a function call. --- The `generateCall` function takes a function expression and a list of argument expressions, --- and returns the corresponding LLVM operand. -generateCall :: (MonadCodegen m) => AT.Expr -> AT.Expr -> m AST.Operand -generateCall func args = do - func' <- case func of - AT.Lambda params body -> do - uniqueName <- IRM.freshName "lambda" - buildLambda uniqueName params body - AT.Var name -> do - maybeOp <- getVarBinding name - case maybeOp of - Just op -> pure op - Nothing -> E.throwError $ VariableNotFound name - _ -> generateExpr func - case args of - (AT.Seq args') -> do - argOps <- mapM generateExpr args' - I.call func' [(argOp, []) | argOp <- argOps] - _ -> do - args' <- generateExpr args - I.call func' [(args', [])] - --- | Generates an LLVM operand for a variable. --- The `generateVar` function takes a variable name and returns the corresponding LLVM operand. -generateVar :: (MonadCodegen m) => String -> m AST.Operand -generateVar name = do - maybeOp <- getVarBinding name - case maybeOp of - Just op -> pure op +-- | Generate LLVM code for function definitions. +generateFunction :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateFunction (AT.Function _ name (AT.TFunction ret params False) paramNames body) = do + let funcName = AST.Name $ U.stringToByteString name + paramTypes = zipWith mkParam params paramNames + funcType = T.ptr $ T.FunctionType (toLLVM ret) (map fst paramTypes) False + addGlobalVar name $ AST.ConstantOperand $ C.GlobalReference funcType funcName + M.function funcName paramTypes (toLLVM ret) $ \ops -> do + S.modify (\s -> s {localState = []}) + S.zipWithM_ addVar paramNames ops + result <- generateExpr body + I.ret result + where + mkParam t n = (toLLVM t, M.ParameterName $ U.stringToByteString n) +generateFunction expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for declarations. +generateDeclaration :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateDeclaration (AT.Declaration _ name typ mInitExpr) = do + let llvmType = toLLVM typ + ptr <- I.alloca llvmType Nothing 0 + case mInitExpr of + Just initExpr -> do + initValue <- generateExpr initExpr + I.store ptr 0 initValue + Nothing -> pure () + addVar name ptr + pure ptr +generateDeclaration expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for return statements. +generateReturn :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateReturn (AT.Return _ mExpr) = do + case mExpr of + Just expr -> do + result <- generateExpr expr + I.ret result + pure result Nothing -> do - let globalVarPtr = - AST.ConstantOperand $ - C.GlobalReference (T.ptr T.i64) (AST.mkName name) - I.load globalVarPtr 0 - --- | Generates an LLVM operand for a definition. --- The `generateDefine` function takes a variable name and an expression, --- and returns the corresponding LLVM operand. -generateDefine :: (MonadCodegen m) => String -> AT.Expr -> m AST.Operand -generateDefine name = \case - AT.Lit var -> case var of - AT.LInt i -> do - let op = AST.ConstantOperand (C.Int 64 $ fromIntegral i) - addVarBinding name op - generateExpr (AT.Lit var) - AT.LBool b -> do - let op = AST.ConstantOperand (C.Int 1 $ if b then 1 else 0) - addVarBinding name op - generateExpr (AT.Lit var) - _ -> E.throwError $ UnsupportedLocalVar var - AT.Lambda params body -> buildLambda (AST.mkName name) params body - AT.Var var -> generateVar var - AT.Seq exprs -> generateSeq exprs - expr -> E.throwError $ UnsupportedDefinition expr - --- | Generates an LLVM operand for a lambda expression. --- The `generateLambda` function takes a list of parameter names and a list of body expressions, --- and returns the corresponding LLVM operand. -generateLambda :: (MonadCodegen m) => [String] -> AT.Expr -> m AST.Operand -generateLambda params body = do - uniqueName <- IRM.freshName "lambda" - buildLambda uniqueName params body - --- | Generates an LLVM operand for a sequence of expressions. --- The `generateSeq` function takes a list of expressions and returns the corresponding LLVM operand. -generateSeq :: (MonadCodegen m) => [AT.Expr] -> m AST.Operand -generateSeq = \case - [] -> E.throwError $ UnsupportedTopLevel (AT.Seq []) - [expr] -> generateExpr expr - exprs -> do - CM.forM_ (init exprs) generateExpr - generateExpr (last exprs) + I.retVoid + pure $ AST.ConstantOperand $ C.Undef T.void +generateReturn expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for function calls. +generateFunctionCall :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateFunctionCall (AT.Call loc (AT.Var _ name _) args) = do + maybeFunc <- getVar name + case maybeFunc of + Just funcOperand -> case funcOperand of + AST.ConstantOperand (C.GlobalReference _ _) -> do + operandArgs <- mapM generateExpr args + I.call funcOperand (map (,[]) operandArgs) + _ -> do + funcPtr <- I.load funcOperand 0 + operandArgs <- mapM generateExpr args + I.call funcPtr (map (,[]) operandArgs) + Nothing -> E.throwError $ CodegenError loc $ UnsupportedFunctionCall name +generateFunctionCall expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Check the type of an argument. +checkArgumentType :: (MonadCodegen m) => T.Type -> AT.Expr -> m () +checkArgumentType expectedType expr = do + operand <- generateExpr expr + let actualType = TD.typeOf operand + CM.when (actualType /= expectedType) $ + E.throwError $ + CodegenError (U.getLoc expr) $ + UnsupportedFunctionCall "Argument type mismatch" + +-- | Generate LLVM code for array access. +generateArrayAccess :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateArrayAccess (AT.ArrayAccess loc (AT.Var _ name _) indexExpr) = do + maybeVar <- getVar name + ptr <- case maybeVar of + Just arrayPtr -> return arrayPtr + Nothing -> E.throwError $ CodegenError loc $ VariableNotFound name + index <- generateExpr indexExpr + elementPtr <- I.gep ptr [IC.int32 0, index] + I.load elementPtr 0 +generateArrayAccess expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for type casts. +generateCast :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateCast (AT.Cast _ typ expr) = do + operand <- generateExpr expr + let fromType = TD.typeOf operand + toType = toLLVM typ + case (fromType, toType) of + (T.IntegerType fromBits, T.IntegerType toBits) + | fromBits < toBits -> I.zext operand toType + (T.IntegerType fromBits, T.IntegerType toBits) + | fromBits > toBits -> I.trunc operand toType + (T.IntegerType _, T.FloatingPointType _) -> I.sitofp operand toType + (T.FloatingPointType _, T.IntegerType _) -> I.fptosi operand toType + (T.FloatingPointType _, T.FloatingPointType _) -> I.fptrunc operand toType + (T.ArrayType _ _, T.PointerType _ _) -> I.bitcast operand toType + (T.ArrayType _ _, T.ArrayType _ _) -> I.bitcast operand toType + (T.IntegerType _, T.PointerType _ _) -> I.inttoptr operand toType + _ -> E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedType typ +generateCast expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for for loops. +generateForLoop :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateForLoop (AT.For _ init' cond step body) = mdo + CM.void $ generateExpr init' + + I.br condBlock + + condBlock <- IRM.block `IRM.named` U.stringToByteString "for.cond" + condResult <- generateExpr cond + I.condBr condResult bodyBlock exitBlock + + bodyBlock <- IRM.block `IRM.named` U.stringToByteString "for.body" + + state <- S.gets loopState + S.modify (\s -> s {loopState = Just (stepBlock, exitBlock)}) + + CM.void $ generateExpr body + + S.modify (\s -> s {loopState = state}) + + I.br stepBlock + + stepBlock <- IRM.block `IRM.named` U.stringToByteString "for.step" + CM.void $ generateExpr step + I.br condBlock + + exitBlock <- IRM.block `IRM.named` U.stringToByteString "for.exit" + + pure $ AST.ConstantOperand $ C.Null T.i8 +generateForLoop expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedForDefinition expr + +-- | Generate LLVM code for while loops. +generateWhileLoop :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateWhileLoop (AT.While _ cond body) = mdo + I.br condBlock + + condBlock <- IRM.block `IRM.named` U.stringToByteString "while.cond" + condOperand <- generateExpr cond + I.condBr condOperand bodyBlock exitBlock + + bodyBlock <- IRM.block `IRM.named` U.stringToByteString "while.body" + + state <- S.gets loopState + S.modify (\s -> s {loopState = Just (condBlock, exitBlock)}) + + CM.void $ generateExpr body + + S.modify (\s -> s {loopState = state}) + + I.br condBlock + + exitBlock <- IRM.block `IRM.named` U.stringToByteString "while.exit" + + pure $ AST.ConstantOperand $ C.Null T.i8 +generateWhileLoop expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedWhileDefinition expr + +-- | Generate LLVM code for break statements. +generateBreak :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateBreak (AT.Break loc) = do + state <- S.get + case loopState state of + Just (_, breakBlock) -> do + I.br breakBlock + pure $ AST.ConstantOperand $ C.Undef T.void + Nothing -> E.throwError $ CodegenError loc BreakOutsideLoop +generateBreak expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +generateContinue :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateContinue (AT.Continue loc) = do + state <- S.get + case loopState state of + Just (continueBlock, _) -> do + I.br continueBlock + pure $ AST.ConstantOperand $ C.Undef T.void + Nothing -> E.throwError $ CodegenError loc ContinueOutsideLoop +generateContinue expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr + +-- | Generate LLVM code for assignments. +generateAssignment :: (MonadCodegen m) => AT.Expr -> m AST.Operand +generateAssignment (AT.Assignment _ expr valueExpr) = do + value <- generateExpr valueExpr + case expr of + AT.Var _ name _ -> do + maybeVar <- getVar name + case maybeVar of + Just ptr -> do + I.store ptr 0 value + pure value + Nothing -> E.throwError $ CodegenError (U.getLoc expr) $ VariableNotFound name + AT.ArrayAccess _ (AT.Var _ name _) indexExpr -> do + maybeVar <- getVar name + ptr <- case maybeVar of + Just arrayPtr -> return arrayPtr + Nothing -> E.throwError $ CodegenError (U.getLoc expr) $ VariableNotFound name + index <- generateExpr indexExpr + elementPtr <- I.gep ptr [IC.int32 0, index] + I.store elementPtr 0 value + pure value + _ -> E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr +generateAssignment expr = + E.throwError $ CodegenError (U.getLoc expr) $ UnsupportedDefinition expr diff --git a/lib/Codegen/Utils.hs b/lib/Codegen/Utils.hs index d87dc20..e3519bb 100644 --- a/lib/Codegen/Utils.hs +++ b/lib/Codegen/Utils.hs @@ -1,5 +1,6 @@ module Codegen.Utils where +import qualified Ast.Types as AT import qualified Data.ByteString.Short as BS import qualified LLVM.AST as AST @@ -15,3 +16,25 @@ byteStringToString = map (toEnum . fromIntegral) . BS.unpack nameToString :: AST.Name -> String nameToString (AST.Name n) = byteStringToString n nameToString (AST.UnName n) = show n + +-- | Extracts the location from an AST node. +getLoc :: AT.Expr -> AT.SrcLoc +getLoc expr = case expr of + AT.Lit loc _ -> loc + AT.Var loc _ _ -> loc + AT.StructAccess loc _ _ -> loc + AT.ArrayAccess loc _ _ -> loc + AT.UnaryOp loc _ _ -> loc + AT.Call loc _ _ -> loc + AT.If loc _ _ _ -> loc + AT.While loc _ _ -> loc + AT.For loc _ _ _ _ -> loc + AT.Return loc _ -> loc + AT.Break loc -> loc + AT.Continue loc -> loc + AT.Cast loc _ _ -> loc + AT.Declaration loc _ _ _ -> loc + AT.Assignment loc _ _ -> loc + AT.Op loc _ _ _ -> loc + AT.Function loc _ _ _ _ -> loc + AT.Block exprs -> getLoc $ head exprs diff --git a/lib/Misc.hs b/lib/Misc.hs deleted file mode 100644 index 7519081..0000000 --- a/lib/Misc.hs +++ /dev/null @@ -1,4 +0,0 @@ -module Misc (addOne) where - -addOne :: Int -> Int -addOne x = x + 1 diff --git a/test/Codegen/CodegenSpec.hs b/test/Codegen/CodegenSpec.hs new file mode 100644 index 0000000..cae73e6 --- /dev/null +++ b/test/Codegen/CodegenSpec.hs @@ -0,0 +1,249 @@ +module Codegen.CodegenSpec (spec) where + +import qualified Ast.Types as AT +import qualified Codegen.Codegen as C +import qualified Codegen.Utils as U +import qualified LLVM.AST as AST +import qualified LLVM.AST.Global as G +import qualified LLVM.AST.Linkage as L +import qualified LLVM.AST.Type as T +import qualified Test.Hspec as H + +{--# +Simple program which adds two numbers and +returns 1 if the result is greater than 30, +otherwise 0. + +It's equivalent to the following C code: + +``` +int main(void) +{ + int x = 10; + int y = 20; + int result = x + y; + + if (result > 30) { + return 1; + } else { + return 0; + } +} +``` +#--} +simpleSum :: AT.Program +simpleSum = + AT.Program + { AT.globals = [("main", mainDef)], + AT.types = [], + AT.sourceFile = "sample.c" + } + where + sampleLoc = AT.SrcLoc "sample.c" 1 1 + mainDef = + AT.Function + { AT.funcLoc = sampleLoc, + AT.funcName = "main", + AT.funcType = AT.TFunction (AT.TInt 32) [AT.TInt 32] False, + AT.funcParams = [], + AT.funcBody = + AT.Block + [ AT.Declaration + sampleLoc + "x" + (AT.TInt 32) + (Just (AT.Lit sampleLoc (AT.LInt 10))), + AT.Declaration + sampleLoc + "y" + (AT.TInt 32) + (Just (AT.Lit sampleLoc (AT.LInt 20))), + AT.Declaration + sampleLoc + "result" + (AT.TInt 32) + ( Just + ( AT.Op + sampleLoc + AT.Add + (AT.Var sampleLoc "x" (AT.TInt 32)) + (AT.Var sampleLoc "y" (AT.TInt 32)) + ) + ), + AT.If + sampleLoc + ( AT.Op + sampleLoc + AT.Gt + (AT.Var sampleLoc "result" (AT.TInt 32)) + (AT.Lit sampleLoc (AT.LInt 30)) + ) + ( AT.Return + sampleLoc + (Just (AT.Lit sampleLoc (AT.LInt 1))) + ) + ( Just + ( AT.Return + sampleLoc + (Just (AT.Lit sampleLoc (AT.LInt 0))) + ) + ) + ] + } + +spec :: H.Spec +spec = H.describe "Codegen" $ do + H.context "when generating code for a simple sum program" $ do + let getMainFunc mod' = do + let defs = AST.moduleDefinitions mod' + let mainFuncs = + [ f + | AST.GlobalDefinition f@(AST.Function {}) <- defs, + G.name f == AST.Name (U.stringToByteString "main") + ] + head mainFuncs + + let withModule test = do + case C.codegen simpleSum of + Left err -> + H.expectationFailure $ "Failed to generate module: " ++ show err + Right mod' -> test mod' + + H.it "should generate a valid module with main function" $ withModule $ \mod' -> do + let mainFunc = getMainFunc mod' + let (mainParams, isVarArg) = G.parameters mainFunc + length mainParams `H.shouldBe` 0 + isVarArg `H.shouldBe` False + + H.it "should have correct function type" $ withModule $ \mod' -> do + let mainFunc = getMainFunc mod' + G.returnType mainFunc `H.shouldBe` T.i32 + + H.it "should have the correct number of basic blocks" $ withModule $ \mod' -> do + let mainFunc = getMainFunc mod' + length (G.basicBlocks mainFunc) `H.shouldBe` 4 + + H.it "should have correct module name" $ withModule $ \mod' -> + AST.moduleName mod' `H.shouldBe` U.stringToByteString "sample.c" + + H.it "should not have internal, aka static, functions" $ withModule $ \mod' -> do + let internals = + [ f + | AST.GlobalDefinition f@(AST.Function {}) <- AST.moduleDefinitions mod', + G.linkage f == L.Internal + ] + length internals `H.shouldBe` 0 + + H.it "should have external linkage for main" $ withModule $ \mod' -> do + let mainFunc = getMainFunc mod' + G.linkage mainFunc `H.shouldBe` L.External + + H.it "should have conditional branch instruction" $ withModule $ \mod' -> do + let mainFunc = getMainFunc mod' + let blocks = G.basicBlocks mainFunc + let (G.BasicBlock _ _ terminator) = head blocks + case terminator of + AST.Do (AST.CondBr {}) -> return () + _ -> H.expectationFailure "Expected conditional branch instruction" + + H.context "when testing individual codegen functions" $ do + let wrapInFunction expr = + AT.Function + { AT.funcLoc = sampleLoc, + AT.funcName = "test", + AT.funcType = AT.TFunction (AT.TInt 32) [] False, + AT.funcParams = [], + AT.funcBody = + AT.Block + [ expr, + AT.Return sampleLoc (Just (AT.Lit sampleLoc (AT.LInt 0))) + ] + } + + H.describe "generateIf" $ do + H.it "should generate correct branch structure" $ do + let ifExpr = + AT.If + sampleLoc + (AT.Lit sampleLoc (AT.LBool True)) + (AT.Lit sampleLoc (AT.LInt 1)) + (Just (AT.Lit sampleLoc (AT.LInt 0))) + + let blocks = generateTestBlocks (wrapInFunction ifExpr) + length blocks `H.shouldBe` 4 + + H.describe "generateVar" $ do + H.it "should handle variable lookup correctly" $ do + let varExpr = AT.Var sampleLoc "x" (AT.TInt 32) + let varDecl = + AT.Declaration + sampleLoc + "x" + (AT.TInt 32) + (Just (AT.Lit sampleLoc (AT.LInt 42))) + + let blocks = generateTestBlocks (wrapInFunction (AT.Block [varDecl, varExpr])) + length blocks `H.shouldBe` 1 + + H.describe "generateBinaryOp" $ do + H.it "should generate correct arithmetic operations" $ do + let addExpr = + AT.Op + sampleLoc + AT.Add + (AT.Lit sampleLoc (AT.LInt 5)) + (AT.Lit sampleLoc (AT.LInt 3)) + + let blocks = generateTestBlocks (wrapInFunction addExpr) + let instrs = getInstructions blocks + any isAddInstr instrs `H.shouldBe` True + + H.describe "generateFunction" $ do + H.it "should create function with correct signature" $ do + let funcExpr = + AT.Function + sampleLoc + "test" + (AT.TFunction (AT.TInt 32) [] False) + [] + (AT.Block [AT.Return sampleLoc (Just (AT.Lit sampleLoc (AT.LInt 0)))]) + + let blocks = generateTestBlocks funcExpr + length blocks `H.shouldBe` 1 + + H.describe "generateDeclaration" $ do + H.it "should allocate and initialize variables" $ do + let declExpr = + AT.Declaration + sampleLoc + "x" + (AT.TInt 32) + (Just (AT.Lit sampleLoc (AT.LInt 42))) + + let blocks = generateTestBlocks (wrapInFunction declExpr) + let instrs = getInstructions blocks + any isAllocaInstr instrs `H.shouldBe` True + any isStoreInstr instrs `H.shouldBe` True + where + sampleLoc = AT.SrcLoc "test.c" 1 1 + + generateTestBlocks expr = case C.codegen testProg of + Right mod' -> concatMap G.basicBlocks $ getDefinitions mod' + Left _ -> [] + where + testProg = AT.Program [("test", expr)] [] "test.c" + + getDefinitions mod' = + [f | AST.GlobalDefinition f@(AST.Function {}) <- AST.moduleDefinitions mod'] + + getInstructions blocks = + [i | G.BasicBlock _ instrs _ <- blocks, i <- instrs] + + isAddInstr (AST.UnName _ AST.:= AST.Add {}) = True + isAddInstr _ = False + + isAllocaInstr (AST.UnName _ AST.:= AST.Alloca {}) = True + isAllocaInstr _ = False + + isStoreInstr (AST.Do (AST.Store {})) = True + isStoreInstr _ = False diff --git a/test/Misc/MiscSpec.hs b/test/Misc/MiscSpec.hs deleted file mode 100644 index 296fa44..0000000 --- a/test/Misc/MiscSpec.hs +++ /dev/null @@ -1,16 +0,0 @@ -module Misc.MiscSpec (spec) where - -import Misc (addOne) -import Test.Hspec (Spec, describe, it, shouldBe) -import Test.QuickCheck (property) - -spec :: Spec -spec = do - describe "addOne" $ do - it "adds one to a number" $ do - addOne 1 `shouldBe` 2 - it "adds one to a negative number" $ do - addOne (-1) `shouldBe` 0 - it "is the inverse of subtracting one" $ - property $ - \x -> addOne (x - 1) == x