diff --git a/src/main/kotlin/org/serenityos/jakt/annotations/BasicAnnotator.kt b/src/main/kotlin/org/serenityos/jakt/annotations/BasicAnnotator.kt index 8f41a283..33b75464 100644 --- a/src/main/kotlin/org/serenityos/jakt/annotations/BasicAnnotator.kt +++ b/src/main/kotlin/org/serenityos/jakt/annotations/BasicAnnotator.kt @@ -11,7 +11,7 @@ import com.intellij.refactoring.suggested.startOffset import org.serenityos.jakt.JaktTypes import org.serenityos.jakt.psi.ancestorOfType import org.serenityos.jakt.psi.api.* -import org.serenityos.jakt.psi.findChildrenOfType +import org.serenityos.jakt.psi.findChildOfType import org.serenityos.jakt.psi.reference.JaktPlainQualifierMixin import org.serenityos.jakt.psi.reference.exprAncestor import org.serenityos.jakt.psi.reference.hasNamespace @@ -61,14 +61,14 @@ object BasicAnnotator : JaktAnnotator(), DumbAware { is JaktImportBraceEntry -> element.identifier.highlight(Highlights.IMPORT_ENTRY) is JaktExternImport -> element.cSpecifier?.highlight(Highlights.KEYWORD_DECLARATION) is JaktImport -> { - val idents = element.findChildrenOfType(JaktTypes.IDENTIFIER) - idents.first().highlight(Highlights.IMPORT_MODULE) + element.importTarget.identifier?.highlight(Highlights.IMPORT_MODULE) - if (idents.size > 1) { + + if (element.keywordAs != null) { + element.findChildOfType(JaktTypes.IDENTIFIER)?.highlight(Highlights.IMPORT_ALIAS) // The 'as' keyword will be highlighted as an operator here without // the annotation element.keywordAs!!.highlight(Highlights.KEYWORD_IMPORT) - idents[1].highlight(Highlights.IMPORT_ALIAS) } } is JaktArgument -> { diff --git a/src/main/kotlin/org/serenityos/jakt/comptime/Interpreter.kt b/src/main/kotlin/org/serenityos/jakt/comptime/Interpreter.kt new file mode 100644 index 00000000..1e20f032 --- /dev/null +++ b/src/main/kotlin/org/serenityos/jakt/comptime/Interpreter.kt @@ -0,0 +1,922 @@ +package org.serenityos.jakt.comptime + +import com.intellij.openapi.util.TextRange +import com.intellij.psi.PsiElement +import com.intellij.refactoring.suggested.endOffset +import com.intellij.refactoring.suggested.startOffset +import org.serenityos.jakt.JaktFile +import org.serenityos.jakt.JaktTypes +import org.serenityos.jakt.psi.* +import org.serenityos.jakt.psi.api.* +import org.serenityos.jakt.psi.reference.hasNamespace +import org.serenityos.jakt.utils.unreachable + +class Interpreter(element: JaktPsiElement) { + var scope: Scope + + val stdout = StringBuilder() + val stderr = StringBuilder() + + init { + val outerScopes = element.ancestors().filter { + it is JaktFile || it is JaktFunction || it is JaktStructDeclaration || it is JaktBlock + }.map { + val scope = Scope(null) + if (it is JaktScope) { + for (decl in it.getDeclarations()) { + if (decl is JaktImportBraceEntry) + continue // TODO + + scope[decl.name ?: continue] = when (val result = evaluate(decl)) { + is ExecutionResult.Normal -> result.value + else -> continue + } + } + } + if (it is JaktFile) + initializeGlobalScope(scope) + scope + }.toList() + + for ((index, scope) in outerScopes.withIndex()) { + if (index == 0) + continue + scope.outer = outerScopes[index - 1] + } + + scope = outerScopes.last() + } + + fun pushScope(scope: Scope) { + this.scope = scope + } + + fun popScope() { + for (defer in scope.defers) + evaluate(defer) // TODO: Do something with the return value? + scope = scope.outer!! + } + + // A return value of null indicates the expression is not comptime. An exception being + // thrown indicates the expression is comptime, but it malformed in some way and cannot + // be evaluated. + fun evaluate(element: JaktPsiElement): ExecutionResult { + return when (element) { + /*** EXPRESSIONS ***/ + + is JaktMatchExpression -> TODO("comptime match expressions") + is JaktTryExpression -> { + when (val result = evaluate(element.expression ?: element.blockList[0])) { + is ExecutionResult.Throw -> if (element.catchKeyword != null) { + val pushedScope = if (element.catchDecl != null) { + val newScope = Scope(scope) + newScope[element.catchDecl!!.identifier.text] = result.value + pushScope(newScope) + true + } else false + + evaluate(element.blockList[1]) + + if (pushedScope) + popScope() + + } + !is ExecutionResult.Normal -> return result + else -> {} + } + + ExecutionResult.Normal(VoidValue) + } + is JaktLambdaExpression -> TODO("comptime lambdas") + is JaktAssignmentBinaryExpression -> { + val binaryOp = when { + element.plusEquals != null -> BinaryOperator.Add + element.minusEquals != null -> BinaryOperator.Subtract + element.asteriskEquals != null -> BinaryOperator.Multiply + element.slashEquals != null -> BinaryOperator.Divide + element.percentEquals != null -> BinaryOperator.Modulo + element.arithLeftShiftEquals != null -> BinaryOperator.ArithLeftShift + element.leftShiftEquals != null -> BinaryOperator.LeftShift + element.arithRightShiftEquals != null -> BinaryOperator.ArithRightShift + element.rightShiftEquals != null -> BinaryOperator.RightShift + else -> null + } + + val newValue = if (binaryOp != null) { + applyBinaryOperator(element.left, element.right!!, binaryOp).let { + when (it) { + is ExecutionResult.Normal -> it.value + is ExecutionResult.Yield -> error( + "Unexpected yield", + TextRange(element.left.startOffset, element.right!!.endOffset), + ) + else -> return it + } + } + } else evaluateNonYield(element.right!!) { return it } + + assign(element.left, newValue) + + ExecutionResult.Normal(VoidValue) + } + is JaktThisExpression -> TODO("comptime this expressions") + is JaktFieldAccessExpression -> TODO("comptime field access") + is JaktRangeExpression -> { + val (startExpr, endExpr) = when { + element.expressionList.size == 2 -> element.expressionList[0] to element.expressionList[1] + element.expressionList.isEmpty() -> null to null + element.expressionList[0].textRange.endOffset < element.dotDot.textRange.startOffset -> + element.expressionList[0] to null + else -> null to element.expressionList[0] + } + + val start = startExpr?.let { e -> evaluateNonYield(e) { return it } } ?: IntegerValue(0) + + val end = startExpr?.let { e -> evaluateNonYield(e) { return it } } ?: IntegerValue(Long.MAX_VALUE) + + if (start !is IntegerValue) + error("Expected range start value to be an integer", startExpr!!) + + if (end !is IntegerValue) + error("Expected range end value to be an integer", endExpr!!) + + ExecutionResult.Normal(RangeValue(start.value, end.value, isInclusive = false)) + } + is JaktLogicalOrBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + BinaryOperator.LogicalOr + ) + is JaktLogicalAndBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + BinaryOperator.LogicalAnd + ) + is JaktBitwiseOrBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + BinaryOperator.BitwiseOr + ) + is JaktBitwiseXorBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + BinaryOperator.BitwiseXor + ) + is JaktBitwiseAndBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + BinaryOperator.BitwiseAnd + ) + is JaktRelationalBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + when { + element.doubleEquals != null -> BinaryOperator.Equals + element.notEquals != null -> BinaryOperator.NotEquals + element.greaterThan != null -> BinaryOperator.GreaterThan + element.greaterThanEquals != null -> BinaryOperator.GreaterThanEq + element.lessThan != null -> BinaryOperator.LessThan + else -> BinaryOperator.LessThanEq + }, + ) + is JaktShiftBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + when { + element.leftShift != null -> BinaryOperator.LeftShift + element.arithLeftShift != null -> BinaryOperator.ArithLeftShift + element.rightShift != null -> BinaryOperator.RightShift + else -> BinaryOperator.ArithRightShift + }, + ) + is JaktAddBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + if (element.plus != null) BinaryOperator.Add else BinaryOperator.Subtract, + ) + is JaktMultiplyBinaryExpression -> applyBinaryOperator( + element.left, + element.right!!, + when { + element.asterisk != null -> BinaryOperator.Multiply + element.slash != null -> BinaryOperator.Divide + else -> BinaryOperator.Modulo + }, + ) + is JaktCastExpression -> TODO("comptime casts") + is JaktIsExpression -> TODO("comptime is expressions") + is JaktUnaryExpression -> applyUnaryOperator( + element.expression, + when { + element.minus != null -> UnaryOperator.Minus + element.keywordNot != null -> UnaryOperator.Not + element.tilde != null -> UnaryOperator.BitwiseNot + element.ampersand != null -> when { + element.rawKeyword != null -> UnaryOperator.RawReference + element.mutKeyword != null -> UnaryOperator.MutReference + else -> UnaryOperator.Reference + } + element.asterisk != null -> UnaryOperator.Dereference + element.exclamationPoint != null -> UnaryOperator.Unwrap + else -> { + val plusPlus = element.findChildOfType(JaktTypes.PLUS_PLUS) + if (plusPlus != null) { + if (plusPlus.startOffset < element.expression.startOffset) { + UnaryOperator.PrefixIncrement + } else UnaryOperator.PostfixIncrement + } else { + val minusMinus = element.findChildOfType(JaktTypes.MINUS_MINUS)!! + if (minusMinus.startOffset < element.expression.startOffset) { + UnaryOperator.PrefixDecrement + } else UnaryOperator.PostfixDecrement + } + } + } + ) + is JaktBooleanLiteral -> ExecutionResult.Normal(BoolValue(element.trueKeyword != null)) + is JaktNumericLiteral -> { + element.binaryLiteral?.let { + return ExecutionResult.Normal(IntegerValue(it.text.drop(2).toLong(2))) + } + + element.octalLiteral?.let { + return ExecutionResult.Normal(IntegerValue(it.text.drop(2).toLong(8))) + } + + element.hexLiteral?.let { + return ExecutionResult.Normal(IntegerValue(it.text.drop(2).toLong(16))) + } + + val decimalText = element.decimalLiteral!!.text + return ExecutionResult.Normal( + if ("." in decimalText) { + FloatValue(decimalText.toDouble()) + } else IntegerValue(decimalText.toLong(10)) + ) + } + is JaktLiteral -> { + element.byteCharLiteral?.let { + return ExecutionResult.Normal(ByteCharValue(it.text[2].code.toByte())) + } + + element.charLiteral?.let { + return ExecutionResult.Normal(CharValue(it.text[1].code.toChar())) + } + + ExecutionResult.Normal(StringValue(element.stringLiteral!!.text.drop(1).dropLast(1))) + } + is JaktAccessExpression -> { + val target = evaluateNonYield(element.expression) { return it } + + if (element.dotQuestionMark != null) + TODO() + + if (element.decimalLiteral != null) { + if (target !is TupleValue) + error("Invalid tuple index into non-tuple value", element.decimalLiteral!!) + + ExecutionResult.Normal(target.values[element.decimalLiteral!!.text.toInt()]) + } else { + val value = target[element.identifier!!.text] ?: error("Unknown field ${element.identifier!!.text}") + ExecutionResult.Normal(value) + } + } + is JaktIndexedAccessExpression -> { + val target = evaluateNonYield(element.expressionList[0]) { return it } + val value = evaluateNonYield(element.expressionList[1]) { return it } + + if (target !is ArrayValue) + error("Unexpected index into non-array value", element.expressionList[0]) + + when (value) { + is RangeValue -> ExecutionResult.Normal(ArraySlice(target, value.range)) + is IntegerValue -> { + val result = target.values.getOrNull(value.value.toInt()) + ?: error("Index ${value.value} out-of-range for array of length ${target.values.size}") + ExecutionResult.Normal(result) + } + else -> error("Expected integer or range in array indexing expression", element.expressionList[1]) + } + } + is JaktPlainQualifierExpression -> { + val qualifier = element.plainQualifier + + val parts = generateSequence(qualifier) { it.plainQualifier } + .map { it to it.identifier.text } + .toMutableList() + .asReversed() + + var value: Value? = null + var currScope: Scope? = scope + + while (currScope != null) { + if (parts[0].second in currScope) { + value = scope[parts[0].second] + parts.removeFirst() + break + } + + currScope = currScope.outer + } + + if (value == null) { + val type = if (parts.size > 1) "qualifier" else "identifier" + error("Unknown $type \"${parts[0].second}\"", parts[0].first) + } + + for (part in parts) { + if (part.second !in value!!) + error("\"${value.typeName()}\" has no member named ${part.second}", part.first) + + value = value[part.second]!! + } + + ExecutionResult.Normal(value!!) + } + is JaktCallExpression -> { + val target = evaluateNonYield(element.expression) { return it } + + if (target !is FunctionValue) + error("\"${target.typeName()}\" is not callable", element.expression) + + val args = element.argumentList.argumentList.map { arg -> + evaluateNonYield(arg.expression) { return it } + } + + if (args.size !in target.validParamCount) { + error( + "Expected between ${target.validParamCount.first} and ${target.validParamCount.last} " + + "arguments, but found ${args.size} arguments", + element.argumentList + ) + } + + val thisValue = when (val expr = element.expression) { + is JaktAccessExpression -> evaluateNonYield(expr.expression) { return it } + is JaktFieldAccessExpression -> (scope as FunctionScope).thisBinding!! + is JaktIndexedAccessExpression -> evaluateNonYield(expr.expressionList.first()) { return it } + else -> null + } + + target.call(this, thisValue, args) + } + is JaktArrayExpression -> { + element.elementsArrayBody?.let { body -> + val array = ArrayValue(body.expressionList.map { expr -> + evaluateNonYield(expr) { return it } + }.toMutableList()) + + return ExecutionResult.Normal(array) + } + + val body = element.sizedArrayBody!! + + val value = evaluateNonYield(body.expressionList[0]) { return it } + val size = evaluateNonYield(body.expressionList[1]) { return it } + + if (size !is IntegerValue) + error("Array size initializer must be an integer", body.expressionList[1]) + + ExecutionResult.Normal(ArrayValue((0 until size.value).map { value }.toMutableList())) + } + is JaktDictionaryExpression -> ExecutionResult.Normal( + DictionaryValue( + element.dictionaryElementList.associate { pair -> + val key = evaluateNonYield(pair.expressionList[0]) { return it } + val value = evaluateNonYield(pair.expressionList[1]) { return it } + key to value + }.toMutableMap() + ) + ) + is JaktSetExpression -> ExecutionResult.Normal(SetValue(element.expressionList.map { e -> + evaluateNonYield(e) { return it } + }.toMutableSet())) + is JaktTupleExpression -> ExecutionResult.Normal(TupleValue(element.expressionList.map { e -> + evaluateNonYield(e) { return it } + }.toMutableList())) + is JaktParenExpression -> evaluate(element.expression!!) + + /*** STATEMENTS ***/ + + is JaktExpressionStatement -> { + evaluateNonYield(element.expression) { return it } + ExecutionResult.Normal(VoidValue) + } + is JaktReturnStatement -> ExecutionResult.Return(element.expression?.let {e -> + evaluateNonYield(e) { return it } + } ?: VoidValue) + is JaktThrowStatement -> ExecutionResult.Throw(evaluateNonYield(element.expression) { return it }) + is JaktDeferStatement -> { + scope.defers.add(element.statement) + ExecutionResult.Normal(VoidValue) + } + is JaktIfStatement -> { + val condition = evaluateNonYield(element.expression) { return it } + + if (condition !is BoolValue) + error("Expected bool", element.expression) + + if (condition.value) { + evaluateNonYield(element.block) { return it } + } else if (element.ifStatement != null) { + evaluateNonYield(element.ifStatement!!) { return it } + } else if (element.elseBlock != null) { + evaluateNonYield(element.elseBlock!!) { return it } + } + + ExecutionResult.Normal(VoidValue) + } + is JaktWhileStatement -> { + while (true) { + val exprResult = evaluateNonYield(element.expression) { return it } + if (exprResult !is BoolValue) + error("Expected bool value, found ${exprResult.typeName()}", element.expression) + + if (!exprResult.value) + break + + when (val blockResult = evaluate(element.block)) { + is ExecutionResult.Break -> break + is ExecutionResult.Continue, is ExecutionResult.Normal -> {} + else -> return blockResult + } + } + + ExecutionResult.Normal(VoidValue) + } + is JaktLoopStatement -> { + while (true) { + when (val blockResult = evaluate(element.block)) { + is ExecutionResult.Break -> break + is ExecutionResult.Continue, is ExecutionResult.Normal -> {} + else -> return blockResult + } + } + + ExecutionResult.Normal(VoidValue) + } + is JaktForStatement -> TODO("comptime for statements") + is JaktVariableDeclarationStatement -> { + if (element.parenOpen != null) { + error( + "Destructuring variable assignments are not supported", + TextRange(element.parenOpen!!.startOffset, element.parenClose!!.endOffset), + ) + } + + val rhs = evaluateNonYield(element.expression) { return it } + assign(element.variableDeclList[0].name!!, rhs, initialize = true) + + ExecutionResult.Normal(VoidValue) + } + is JaktGuardStatement -> { + val condition = evaluateNonYield(element.expression) { return it } + if (condition !is BoolValue) + error("Expected bool value, found ${condition.typeName()}", element.expression) + + if (condition.value) { + when (val result = evaluate(element.block)) { + is ExecutionResult.Normal -> error("Unexpected fallthrough from guard block", element.block) + else -> return result + } + } + + ExecutionResult.Normal(VoidValue) + } + is JaktYieldStatement -> ExecutionResult.Yield(evaluateNonYield(element.expression) { return it }) + is JaktBreakStatement -> ExecutionResult.Break + is JaktContinueStatement -> ExecutionResult.Continue + is JaktUnsafeStatement -> error("Cannot evaluate unsafe blocks at comptime", element) + is JaktInlineCppStatement -> error("Cannot evaluate inline cpp blocks at comptime") + is JaktBlock -> { + pushScope(Scope(scope)) + try { + element.statementList.forEach { + val result = evaluate(it) + if (result !is ExecutionResult.Normal) + return result + } + ExecutionResult.Normal(VoidValue) + } finally { + popScope() + } + } + is JaktFunction -> { + val parameters = element.parameterList.parameterList.map { param -> + val default = param.expression?.let { e -> + evaluateNonYield(e) { return it } + } + + FunctionValue.Parameter(param.identifier.text, default) + } + + val target = element.block ?: element.expression ?: return ExecutionResult.Normal(VoidValue) + ExecutionResult.Normal(UserFunctionValue(parameters, target)) + } + + // Ignored declarations (hoisted at scope initialization) + is JaktImport -> ExecutionResult.Normal(VoidValue) + + else -> error("${element::class.simpleName} is not supported at comptime") + } + } + + private fun applyBinaryOperator( + lhsExpr: JaktExpression, + rhsExpr: JaktExpression, + op: BinaryOperator + ): ExecutionResult { + if (op == BinaryOperator.LogicalOr || op == BinaryOperator.LogicalAnd) { + val shortCircuitValue = op == BinaryOperator.LogicalOr + + val lhsValue = evaluateNonYield(lhsExpr) { return it } + if (lhsValue !is BoolValue) + error("Expected bool, found ${lhsValue.typeName()}", lhsExpr) + + if (lhsValue.value == shortCircuitValue) + return ExecutionResult.Normal(BoolValue(shortCircuitValue)) + + val rhsValue = evaluateNonYield(rhsExpr) { return it } + if (rhsValue !is BoolValue) + error("Expected bool, found ${rhsValue.typeName()}", rhsExpr) + + return ExecutionResult.Normal(rhsValue) + } + + val lhsValue = evaluateNonYield(lhsExpr) { return it } + val rhsValue = evaluateNonYield(rhsExpr) { return it } + + fun incompatError(): Nothing { + error( + "Incompatible types \"${lhsValue.typeName()}\" and \"${rhsValue.typeName()}\" for operator ${op.op}", + TextRange(lhsExpr.textRange.startOffset, rhsExpr.textRange.endOffset) + ) + } + + if (lhsValue.typeName() != rhsValue.typeName()) + incompatError() + + // TODO: Separator integer types + + val value = when (op) { + BinaryOperator.BitwiseOr, + BinaryOperator.BitwiseXor, + BinaryOperator.BitwiseAnd, + BinaryOperator.LeftShift, + BinaryOperator.RightShift, + BinaryOperator.ArithLeftShift, + BinaryOperator.ArithRightShift -> { + if (lhsValue !is IntegerValue || rhsValue !is IntegerValue) + incompatError() + + val value = when (op) { + BinaryOperator.BitwiseOr -> lhsValue.value or rhsValue.value + BinaryOperator.BitwiseXor -> lhsValue.value xor rhsValue.value + BinaryOperator.BitwiseAnd -> lhsValue.value and rhsValue.value + BinaryOperator.LeftShift, BinaryOperator.ArithLeftShift -> lhsValue.value shl rhsValue.value.toInt() + BinaryOperator.RightShift, BinaryOperator.ArithRightShift -> lhsValue.value shr rhsValue.value.toInt() + else -> unreachable() + } + + IntegerValue(value) + } + BinaryOperator.Add, + BinaryOperator.Subtract, + BinaryOperator.Multiply, + BinaryOperator.Divide, + BinaryOperator.Modulo -> { + val lhsNum = when (lhsValue) { + is IntegerValue -> lhsValue.value.toDouble() + is FloatValue -> lhsValue.value + else -> incompatError() + } + + val rhsNum = when (rhsValue) { + is IntegerValue -> rhsValue.value.toDouble() + is FloatValue -> rhsValue.value + else -> incompatError() + } + + val result = when (op) { + BinaryOperator.Add -> lhsNum + rhsNum + BinaryOperator.Subtract -> lhsNum - rhsNum + BinaryOperator.Multiply -> lhsNum * rhsNum + BinaryOperator.Divide -> lhsNum / rhsNum + BinaryOperator.Modulo -> lhsNum % rhsNum + else -> unreachable() + } + + if (lhsValue is IntegerValue) { + IntegerValue(result.toLong()) + } else FloatValue(result) + } + BinaryOperator.Equals -> when (lhsValue) { + is BoolValue -> BoolValue(lhsValue.value == (rhsValue as BoolValue).value) + is IntegerValue -> BoolValue(lhsValue.value == (rhsValue as IntegerValue).value) + is FloatValue -> BoolValue(lhsValue.value == (rhsValue as FloatValue).value) + is CharValue -> BoolValue(lhsValue.value == (rhsValue as CharValue).value) + is ByteCharValue -> BoolValue(lhsValue.value == (rhsValue as ByteCharValue).value) + is StringValue -> BoolValue(lhsValue.value == (rhsValue as StringValue).value) + else -> incompatError() + } + BinaryOperator.NotEquals -> when (lhsValue) { + is BoolValue -> BoolValue(lhsValue.value != (rhsValue as BoolValue).value) + is IntegerValue -> BoolValue(lhsValue.value != (rhsValue as IntegerValue).value) + is FloatValue -> BoolValue(lhsValue.value != (rhsValue as FloatValue).value) + is CharValue -> BoolValue(lhsValue.value != (rhsValue as CharValue).value) + is ByteCharValue -> BoolValue(lhsValue.value != (rhsValue as ByteCharValue).value) + is StringValue -> BoolValue(lhsValue.value != (rhsValue as StringValue).value) + else -> incompatError() + } + BinaryOperator.GreaterThan -> when (lhsValue) { + is BoolValue -> BoolValue(lhsValue.value > (rhsValue as BoolValue).value) + is IntegerValue -> BoolValue(lhsValue.value > (rhsValue as IntegerValue).value) + is FloatValue -> BoolValue(lhsValue.value > (rhsValue as FloatValue).value) + is CharValue -> BoolValue(lhsValue.value > (rhsValue as CharValue).value) + is ByteCharValue -> BoolValue(lhsValue.value > (rhsValue as ByteCharValue).value) + is StringValue -> BoolValue(lhsValue.value > (rhsValue as StringValue).value) + else -> incompatError() + } + BinaryOperator.GreaterThanEq -> when (lhsValue) { + is BoolValue -> BoolValue(lhsValue.value >= (rhsValue as BoolValue).value) + is IntegerValue -> BoolValue(lhsValue.value >= (rhsValue as IntegerValue).value) + is FloatValue -> BoolValue(lhsValue.value >= (rhsValue as FloatValue).value) + is CharValue -> BoolValue(lhsValue.value >= (rhsValue as CharValue).value) + is ByteCharValue -> BoolValue(lhsValue.value >= (rhsValue as ByteCharValue).value) + is StringValue -> BoolValue(lhsValue.value >= (rhsValue as StringValue).value) + else -> incompatError() + } + BinaryOperator.LessThan -> when (lhsValue) { + is BoolValue -> BoolValue(lhsValue.value < (rhsValue as BoolValue).value) + is IntegerValue -> BoolValue(lhsValue.value < (rhsValue as IntegerValue).value) + is FloatValue -> BoolValue(lhsValue.value < (rhsValue as FloatValue).value) + is CharValue -> BoolValue(lhsValue.value < (rhsValue as CharValue).value) + is ByteCharValue -> BoolValue(lhsValue.value < (rhsValue as ByteCharValue).value) + is StringValue -> BoolValue(lhsValue.value < (rhsValue as StringValue).value) + else -> incompatError() + } + BinaryOperator.LessThanEq -> when (lhsValue) { + is BoolValue -> BoolValue(lhsValue.value <= (rhsValue as BoolValue).value) + is IntegerValue -> BoolValue(lhsValue.value <= (rhsValue as IntegerValue).value) + is FloatValue -> BoolValue(lhsValue.value <= (rhsValue as FloatValue).value) + is CharValue -> BoolValue(lhsValue.value <= (rhsValue as CharValue).value) + is ByteCharValue -> BoolValue(lhsValue.value <= (rhsValue as ByteCharValue).value) + is StringValue -> BoolValue(lhsValue.value <= (rhsValue as StringValue).value) + else -> incompatError() + } + else -> unreachable() + } + + return ExecutionResult.Normal(value) + } + + private fun applyUnaryOperator(expr: JaktExpression, op: UnaryOperator): ExecutionResult { + return when (op) { + UnaryOperator.PostfixIncrement, UnaryOperator.PostfixDecrement, + UnaryOperator.PrefixIncrement, UnaryOperator.PrefixDecrement -> { + val isPrefix = op == UnaryOperator.PrefixIncrement || op == UnaryOperator.PrefixDecrement + val isIncrement = op == UnaryOperator.PrefixIncrement || op == UnaryOperator.PostfixIncrement + + val delta = if (isIncrement) 1 else -1 + + val currValue = evaluateNonYield(expr) { return it } + val newValue = when (currValue) { + is IntegerValue -> IntegerValue(currValue.value + delta) + is FloatValue -> FloatValue(currValue.value + delta) + else -> error("Invalid type ${currValue.typeName()} for numeric prefix operator '${op.op}'", expr) + } + + assign(expr, newValue)?.let { return it } + + ExecutionResult.Normal(if (isPrefix) newValue else currValue) + } + UnaryOperator.Minus -> { + val value = evaluateNonYield(expr) { return it } + when (value) { + is IntegerValue -> IntegerValue(-value.value) + is FloatValue -> FloatValue(-value.value) + else -> error("Invalid type ${value.typeName()} for unary integer '-' operator", expr) + }.let(ExecutionResult::Normal) + } + UnaryOperator.Not -> { + val value = evaluateNonYield(expr) { return it } + if (value is BoolValue) { + ExecutionResult.Normal(BoolValue(!value.value)) + } else { + error("Invalid type ${value.typeName()} for boolean 'not' operator", expr) + } + } + UnaryOperator.BitwiseNot -> { + val value = evaluateNonYield(expr) { return it } + if (value is IntegerValue) { + ExecutionResult.Normal(IntegerValue(value.value.inv())) + } else { + error("Invalid type ${value.typeName()} for integer '~' operator", expr) + } + } + UnaryOperator.Reference -> TODO("comptime unary reference operator") + UnaryOperator.RawReference -> TODO("comptime unary reference operator") + UnaryOperator.MutReference -> TODO("comptime unary reference operator") + UnaryOperator.Dereference -> TODO("comptime unary dereference operator") + UnaryOperator.Unwrap -> { + val value = evaluateNonYield(expr) { return it } + + when (value) { + is OptionalValue -> if (value.value == null) { + error("Attempt to unwrap empty optional value", expr) + } else { + ExecutionResult.Normal(value.value) + } + else -> error("Invalid type ${value.typeName()} for optional unwrap operator '!'", expr) + } + } + } + } + + private fun assign(expr: JaktExpression, value: Value): ExecutionResult? { + when (expr) { + is JaktPlainQualifierExpression -> { + if (expr.plainQualifier.hasNamespace) + error("Invalid assignment target", expr) + + val name = expr.plainQualifier.name!! + if (!assign(name, value, initialize = false)) + error("Unknown identifier \"$name\"", expr) + } + is JaktIndexedAccessExpression -> { + val target = evaluateNonYield(expr.expressionList[0]) { return it } + + if (target !is ArrayValue) + error("Expected array, found ${target.typeName()}", expr.expressionList[0]) + + val index = evaluateNonYield(expr.expressionList[1]!!) { return it } + + if (index !is IntegerValue) + error("Expected integer, found ${index.typeName()}", expr.expressionList[1]!!) + + if (index.value.toInt() > target.values.size) + error( + "Out-of-bounds assignment to array of length ${target.values.size} with index ${index.value}", + expr + ) + + target.values[index.value.toInt()] = value + } + is JaktAccessExpression -> { + val target = evaluateNonYield(expr.expression) { return it } + + if (expr.decimalLiteral != null) { + if (target !is TupleValue) + error("Expected tuple, found ${target.typeName()}", expr.expression) + + val index = expr.decimalLiteral!!.text.toInt() + if (index > target.values.size) + error("Cannot assign to index $index of tuple of length ${target.values.size}") + + target.values[index] = value + } else { + target[expr.identifier!!.text] = value + } + } + else -> error("Invalid assignment target", expr) + } + + return null + } + + private fun assign(name: String, value: Value, initialize: Boolean): Boolean { + // TODO: Ensure bindings already exists in the scope + + if (initialize) { + scope[name] = value + return true + } + + var currScope: Scope? = scope + while (currScope != null) { + if (name in currScope) { + currScope[name] = value + return true + } + + currScope = currScope.outer + } + + return false + } + + private fun initializeGlobalScope(scope: Scope) { + scope["String"] = StringStruct + scope["StringBuilder"] = StringBuilderStruct + scope["Error"] = ErrorStruct + scope["File"] = FileStruct + scope["___jakt_get_target_triple_string"] = jaktGetTargetTripleStringFunction + scope["abort"] = abortFunction + scope["format"] = FormatFunction + scope["print"] = PrintFunction + scope["println"] = PrintlnFunction + scope["eprint"] = EprintFunction + scope["eprintln"] = EprintlnFunction + } + + private inline fun evaluateNonYield(element: JaktPsiElement, returner: (ExecutionResult) -> Nothing): Value { + when (val result = evaluate(element)) { + is ExecutionResult.Normal -> return result.value + is ExecutionResult.Yield -> error("Unexpected yield", element.descendantOfType() ?: element) + else -> returner(result) + } + } + + fun error(message: String, element: PsiElement): Nothing = error(message, element.textRange) + + fun error(message: String, range: TextRange): Nothing = throw InterpreterException(message, range) + + enum class UnaryOperator(val op: String) { + PrefixIncrement("++"), + PrefixDecrement("--"), + PostfixIncrement("++"), + PostfixDecrement("--"), + Minus("-"), + Not("not"), + BitwiseNot("~"), + Reference("&"), + RawReference("&raw"), + MutReference("&mut"), + Dereference("*"), + Unwrap("!"), + } + + enum class BinaryOperator(val op: String) { + LogicalOr("or"), + LogicalAnd("and"), + BitwiseOr("|"), + BitwiseXor("^"), + BitwiseAnd("&"), + LeftShift("<<"), + RightShift(">>"), + ArithLeftShift("<<<"), + ArithRightShift(">>>"), + Equals("=="), + NotEquals("!="), + GreaterThan(">"), + GreaterThanEq(">="), + LessThan("<"), + LessThanEq("<="), + Add("+"), + Subtract("-"), + Multiply("*"), + Divide("/"), + Modulo("%"), + } + + open class Scope(var outer: Scope?) { + val defers = mutableListOf() + + protected val bindings = mutableMapOf() + + operator fun contains(name: String) = name in bindings + + operator fun get(name: String): Value? = bindings[name] ?: outer?.get(name) + + operator fun set(name: String, value: Value) { + bindings[name] = value + } + } + + class FunctionScope(outer: Scope?, val thisBinding: Value?) : Scope(outer) { + fun argument(name: String) = bindings[name]!! + } + + sealed interface ExecutionResult { + class Return(val value: Value) : ExecutionResult + + class Yield(val value: Value) : ExecutionResult + + class Throw(val value: Value) : ExecutionResult + + class Normal(val value: Value) : ExecutionResult + + object Continue : ExecutionResult + + object Break : ExecutionResult + } + + data class Result( + val value: Value?, + val stdout: String, + val stderr: String, + ) + + companion object { + fun evaluate(element: JaktPsiElement): Result { + val interpreter = Interpreter(element) + + val value = when (val result = interpreter.evaluate(element)) { + is ExecutionResult.Normal -> result.value + else -> null + } + + return Result(value, interpreter.stdout.toString(), interpreter.stderr.toString()) + } + } +} diff --git a/src/main/kotlin/org/serenityos/jakt/comptime/InterpreterException.kt b/src/main/kotlin/org/serenityos/jakt/comptime/InterpreterException.kt new file mode 100644 index 00000000..aeb75284 --- /dev/null +++ b/src/main/kotlin/org/serenityos/jakt/comptime/InterpreterException.kt @@ -0,0 +1,5 @@ +package org.serenityos.jakt.comptime + +import com.intellij.openapi.util.TextRange + +class InterpreterException(message: String, val range: TextRange) : Exception(message) diff --git a/src/main/kotlin/org/serenityos/jakt/comptime/ShowComptimeValueAction.kt b/src/main/kotlin/org/serenityos/jakt/comptime/ShowComptimeValueAction.kt new file mode 100644 index 00000000..b3ca8908 --- /dev/null +++ b/src/main/kotlin/org/serenityos/jakt/comptime/ShowComptimeValueAction.kt @@ -0,0 +1,135 @@ +package org.serenityos.jakt.comptime + +import com.intellij.codeInsight.documentation.DocumentationComponent +import com.intellij.codeInsight.documentation.DocumentationHtmlUtil +import com.intellij.codeInsight.hint.HintManagerImpl +import com.intellij.lang.documentation.DocumentationMarkup +import com.intellij.openapi.actionSystem.AnAction +import com.intellij.openapi.actionSystem.AnActionEvent +import com.intellij.openapi.actionSystem.CommonDataKeys +import com.intellij.openapi.editor.colors.EditorColorsManager +import com.intellij.openapi.editor.impl.EditorCssFontResolver +import com.intellij.openapi.util.registry.Registry +import com.intellij.openapi.util.text.StringUtil +import com.intellij.psi.PsiElement +import com.intellij.ui.LightweightHint +import com.intellij.ui.scale.JBUIScale +import com.intellij.util.ui.HTMLEditorKitBuilder +import com.intellij.util.ui.JBUI +import com.intellij.util.ui.UIUtil +import org.serenityos.jakt.psi.JaktPsiElement +import java.awt.Font +import java.awt.Point +import javax.swing.JEditorPane +import javax.swing.text.StyledDocument + +class ShowComptimeValueAction : AnAction() { + override fun update(e: AnActionEvent) { + e.presentation.isEnabledAndVisible = e.element?.getComptimeTargetElement() != null + } + + override fun actionPerformed(e: AnActionEvent) { + val element = e.element?.getComptimeTargetElement() ?: return + val hint = LightweightHint(ComptimePopup(element)) + val editor = e.getData(CommonDataKeys.EDITOR)!! + + HintManagerImpl.getInstanceImpl().showEditorHint( + hint, + editor, + HintManagerImpl.getHintPosition(hint, editor, editor.caretModel.logicalPosition, 0), + 0, + 0, + false, + ) + } + + private val AnActionEvent.element: PsiElement? + get() { + val file = dataContext.getData(CommonDataKeys.PSI_FILE) ?: return null + val editor = dataContext.getData(CommonDataKeys.EDITOR) ?: return null + return file.findElementAt(editor.caretModel.offset) + } + + private operator fun Point.plus(other: Point) = Point(x + other.x, y + other.y) + + @Suppress("UnstableApiUsage") + private class ComptimePopup(element: JaktPsiElement) : JEditorPane() { + init { + val editorKit = HTMLEditorKitBuilder() + .withFontResolver(EditorCssFontResolver.getGlobalInstance()) + .build() + + DocumentationHtmlUtil.addDocumentationPaneDefaultCssRules(editorKit) + + // Overwrite a rule added by the above call: "html { padding-bottom: 8pm }" + editorKit.styleSheet.addRule("html { padding-bottom: 0px; }") + + this.editorKit = editorKit + border = JBUI.Borders.empty() + + // DocumentationEditorPane::applyFontProps + if (document is StyledDocument) { + val fontName = if (Registry.`is`("documentation.component.editor.font")) { + EditorColorsManager.getInstance().globalScheme.editorFontName + } else font.fontName + + font = UIUtil.getFontWithFallback( + fontName, + Font.PLAIN, + JBUIScale.scale(DocumentationComponent.getQuickDocFontSize().size), + ) + } + + buildText(element) + } + + private fun buildText(element: JaktPsiElement) { + val result = try { + Result.success(element.performComptimeEvaluation()) + } catch (e: Throwable) { + Result.failure(e) + } + + val builder = StringBuilder() + + builder.append("
")
+            // TODO: Render the text
+            builder.append(StringUtil.escapeXmlEntities(element.text))
+            builder.append("
") + if (result.isSuccess) { + val output = result.getOrThrow() + + builder.append(DocumentationMarkup.SECTIONS_START) + builder.append(DocumentationMarkup.SECTION_HEADER_START) + builder.append("Output") + builder.append(DocumentationMarkup.SECTION_SEPARATOR) + if (output.value == null) { + builder.append("Unable to evaluate element") + } else { + builder.append(StringUtil.escapeXmlEntities(output.value.toString())) + } + builder.append(DocumentationMarkup.SECTION_END) + + if (output.stdout.isNotEmpty()) { + builder.append(DocumentationMarkup.SECTION_HEADER_START) + builder.append("stdout") + builder.append(DocumentationMarkup.SECTION_SEPARATOR) + builder.append(StringUtil.escapeXmlEntities(output.stdout).replace("\n", "
")) + builder.append(DocumentationMarkup.SECTION_END) + } + + if (output.stderr.isNotEmpty()) { + builder.append(DocumentationMarkup.SECTION_HEADER_START) + builder.append("stderr") + builder.append(DocumentationMarkup.SECTION_SEPARATOR) + builder.append(StringUtil.escapeXmlEntities(output.stderr).replace("\n", "
")) + builder.append(DocumentationMarkup.SECTION_END) + } + } else { + builder.append(StringUtil.escapeXmlEntities("Internal error: ${result.exceptionOrNull()!!.message}")) + } + + text = builder.toString() + } + } +} diff --git a/src/main/kotlin/org/serenityos/jakt/comptime/Value.kt b/src/main/kotlin/org/serenityos/jakt/comptime/Value.kt new file mode 100644 index 00000000..c9a6a8a8 --- /dev/null +++ b/src/main/kotlin/org/serenityos/jakt/comptime/Value.kt @@ -0,0 +1,109 @@ +package org.serenityos.jakt.comptime + +import org.serenityos.jakt.psi.JaktPsiElement +import org.serenityos.jakt.psi.api.JaktExpression + +// As with everything else in the plugin, we are very lenient when it comes to types. +// All integers are treated as i64, just to make my life easier. Similarly, all float +// types are treated as f64. If we produce a value for something that doesn't actually +// compile, we'll get an IDE error for it anyways +sealed class Value { + private val fields = mutableMapOf() + + abstract fun typeName(): String + + open operator fun contains(name: String) = name in fields + + open operator fun get(name: String) = fields[name] + + open operator fun set(name: String, value: Value) { + fields[name] = value + } +} + +interface PrimitiveValue { + val value: Any +} + +object VoidValue : Value() { + override fun typeName() = "void" + override fun toString() = "void" +} + +data class BoolValue(override val value: Boolean) : Value(), PrimitiveValue { + override fun typeName() = "bool" + override fun toString() = value.toString() +} + +data class IntegerValue(override val value: Long) : Value(), PrimitiveValue { + override fun typeName() = "i64" + override fun toString() = value.toString() +} + +data class FloatValue(override val value: Double) : Value(), PrimitiveValue { + override fun typeName() = "f64" + override fun toString() = value.toString() +} + +data class CharValue(override val value: Char) : Value(), PrimitiveValue { + override fun typeName() = "c_char" + override fun toString() = "'$value'" +} + +data class ByteCharValue(override val value: Byte) : Value(), PrimitiveValue { + override fun typeName() = "u8" + override fun toString() = "b'$value'" +} + +data class TupleValue(val values: MutableList) : Value() { + override fun typeName() = "Tuple" + override fun toString() = values.joinToString(prefix = "(", postfix = ")") +} + +abstract class FunctionValue(private val minParamCount: Int, private val maxParamCount: Int) : Value() { + val validParamCount: IntRange + get() = minParamCount..maxParamCount + + init { + require(minParamCount <= maxParamCount) + } + + override fun typeName() = "function" + + constructor(parameters: List) : this(parameters.count { it.default == null }, parameters.size) + + abstract fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): Interpreter.ExecutionResult + + data class Parameter(val name: String, val default: Value? = null) +} + +class UserFunctionValue( + private val parameters: List, + val body: JaktPsiElement /* JaktBlock | JaktExpression */, // TODO: Storing PSI is bad, right? +) : FunctionValue(parameters) { + override fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): Interpreter.ExecutionResult { + val newScope = Interpreter.FunctionScope(interpreter.scope, thisValue) + + for ((index, param) in parameters.withIndex()) { + if (index <= arguments.lastIndex) { + newScope[param.name] = arguments[index] + } else { + check(param.default != null) + newScope[param.name] = param.default + } + } + + interpreter.pushScope(newScope) + + return when (val result = interpreter.evaluate(body)) { + is Interpreter.ExecutionResult.Normal -> if (body is JaktExpression) { + Interpreter.ExecutionResult.Normal(result.value) + } else Interpreter.ExecutionResult.Normal(VoidValue) + is Interpreter.ExecutionResult.Return -> Interpreter.ExecutionResult.Normal(result.value) + is Interpreter.ExecutionResult.Throw -> result + else -> interpreter.error("Unexpected control flow", body) + }.also { + interpreter.popScope() + } + } +} diff --git a/src/main/kotlin/org/serenityos/jakt/comptime/builtins.kt b/src/main/kotlin/org/serenityos/jakt/comptime/builtins.kt new file mode 100644 index 00000000..1c4133d1 --- /dev/null +++ b/src/main/kotlin/org/serenityos/jakt/comptime/builtins.kt @@ -0,0 +1,764 @@ +package org.serenityos.jakt.comptime + +import org.serenityos.jakt.comptime.Interpreter.ExecutionResult +import org.serenityos.jakt.project.JaktProjectListener +import java.io.File + +class BuiltinFunction( + parameterCount: Int, + private val func: (Value?, List) -> Value, +) : FunctionValue(parameterCount, parameterCount) { + override fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): ExecutionResult { + return ExecutionResult.Normal(func(thisValue, arguments)) + } +} + +data class OptionalValue(val value: Value?) : Value() { + init { + this["has_value"] = hasValue + this["value"] = getValue + this["value_or"] = getValueOr + } + + override fun typeName() = "Optional" + + override fun toString() = "Optional(${value?.toString() ?: ""})" + + companion object { + private val hasValue = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is OptionalValue) + BoolValue(thisValue.value != null) + } + + private val getValue = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is OptionalValue) + thisValue.value!! + } + + private val getValueOr = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is OptionalValue) + thisValue.value ?: arguments[0] + } + } +} + +data class ArrayIterator( + val array: ArrayValue, + private var nextIndex: Int = 0, + private val endInclusiveIndex: Int = array.values.lastIndex, +) : Value() { + init { + this["next"] = next + } + + override fun typeName() = "ArrayIterator" + + override fun toString() = "ArrayIterator($nextIndex..$endInclusiveIndex)" + + companion object { + private val next = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArrayIterator) + if (thisValue.nextIndex > thisValue.endInclusiveIndex) { + OptionalValue(null) + } else OptionalValue(thisValue.array.values[thisValue.nextIndex]).also { + thisValue.nextIndex += 1 + } + } + } +} + +data class ArrayValue(val values: MutableList) : Value() { + init { + this["is_empty"] = isEmpty + this["size"] = size + this["contains"] = contains + this["iterator"] = iterator + this["push"] = push + this["pop"] = pop + this["first"] = first + this["last"] = last + } + + override fun typeName() = "Array" + + override fun toString() = values.joinToString(prefix = "[", postfix = "]") + + companion object { + private val isEmpty = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArrayValue) + BoolValue(thisValue.values.isEmpty()) + } + + private val size = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArrayValue) + IntegerValue(thisValue.values.size.toLong()) + } + + private val contains = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is ArrayValue) + BoolValue(arguments[0] in thisValue.values) + } + + private val iterator = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArrayValue) + ArrayIterator(thisValue) + } + + private val push = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is ArrayValue) + thisValue.values.add(arguments[0]) + VoidValue + } + + private val pop = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArrayValue) + if (thisValue.values.isEmpty()) { + OptionalValue(null) + } else OptionalValue(thisValue.values.removeLast()) + } + + private val first = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArrayValue) + OptionalValue(thisValue.values.firstOrNull()) + } + + private val last = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArrayValue) + OptionalValue(thisValue.values.lastOrNull()) + } + } +} + +data class ArraySlice(val array: ArrayValue, val range: IntRange) : Value() { + init { + this["is_empty"] = isEmpty + this["contains"] = contains + this["size"] = size + this["iterator"] = iterator + this["to_array"] = toArray + this["first"] = first + this["last"] = last + } + + override fun typeName() = "ArraySlice" + + override fun toString() = "ArraySlice($range)" + + companion object { + private val isEmpty = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArraySlice) + BoolValue(thisValue.range.isEmpty() || thisValue.array.values.slice(thisValue.range).isEmpty()) + } + + private val contains = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is ArraySlice) + BoolValue(arguments[0] in thisValue.array.values.slice(thisValue.range)) + } + + private val size = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArraySlice) + + // Note: this is what Jakt does, not totally sure why + if (thisValue.range.last > thisValue.array.values.size) { + IntegerValue(0) + } else IntegerValue((thisValue.range.last - thisValue.range.first).toLong()) + } + + private val iterator = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArraySlice) + ArrayIterator(thisValue.array, thisValue.range.first, thisValue.range.last) + } + + private val toArray = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArraySlice) + ArrayValue(thisValue.array.values.slice(thisValue.range).toMutableList()) + } + + private val first = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArraySlice) + thisValue.array.values[thisValue.range.first] + } + + private val last = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is ArraySlice) + thisValue.array.values[thisValue.range.last] + } + } +} + +object StringStruct : Value() { + private val repeated = BuiltinFunction(2) { thisValue, arguments -> + require(thisValue == null) + val (char, count) = arguments + require(char is CharValue && count is IntegerValue) + StringValue(buildString { + repeat(count.value.toInt()) { append(char.value) } + }) + } + + private val number = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue == null) + val arg = arguments[0] + require(arg is IntegerValue) + StringValue(arg.value.toString()) + } + + init { + this["number"] = number + this["repeated"] = repeated + } + + override fun typeName() = "String" + + override fun toString() = "StringStruct" +} + +data class StringValue(override val value: String) : Value(), PrimitiveValue { + init { + this["is_empty"] = isEmpty + this["length"] = length + this["hash"] = hash + this["substring"] = substring + this["to_uint"] = toUInt + this["to_int"] = toInt + this["is_whitespace"] = isWhitespace + this["contains"] = contains + this["replace"] = replace + this["byte_at"] = byteAt + this["split"] = split + this["starts_with"] = startsWith + this["ends_with"] = endsWith + } + + override fun typeName() = "String" + + override fun toString() = "\"$value\"" + + companion object { + private val whiteSpace = setOf(' ', '\t', '\n', 0xb.toChar() /* \v */, 0xc.toChar() /* \f */, '\r') + + private val isEmpty = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringValue) + BoolValue(thisValue.value.isEmpty()) + } + + private val length = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringValue) + IntegerValue(thisValue.value.length.toLong()) + } + + private val hash = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringValue) + val string = thisValue.value + + // See runtime/Jakt/StringHash.h + if (string.isEmpty()) + return@BuiltinFunction IntegerValue(0L) + + var hash = string.length + + for (ch in string) { + hash += ch.code + hash += hash shl 10 + hash = hash or (hash shr 6) + } + + hash += hash shl 3 + hash = hash or (hash shr 11) + hash += hash shl 15 + + IntegerValue(hash.toLong()) + } + + private val substring = BuiltinFunction(2) { thisValue, arguments -> + require(thisValue is StringValue) + val (start, end) = arguments + require(start is IntegerValue && end is IntegerValue) + StringValue(thisValue.value.substring(start.value.toInt(), (start.value + end.value).toInt())) + } + + private val toUInt = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringValue) + OptionalValue(thisValue.value.toLongOrNull()?.let(::IntegerValue)) + } + + private val toInt = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringValue) + OptionalValue(thisValue.value.toLongOrNull()?.let(::IntegerValue)) + } + + private val isWhitespace = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringValue) + BoolValue(thisValue.value.all { it in whiteSpace }) + } + + private val contains = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringValue) + val arg = arguments[0] + require(arg is StringValue) + BoolValue(arg.value in thisValue.value) + } + + private val replace = BuiltinFunction(2) { thisValue, arguments -> + require(thisValue is StringValue) + val (replace, with) = arguments + require(replace is StringValue && with is StringValue) + StringValue(thisValue.value.replace(replace.value, with.value)) + } + + private val byteAt = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringValue) + val arg = arguments[0] + require(arg is IntegerValue) + IntegerValue(thisValue.value.toByteArray()[arg.value.toInt()].toLong()) + } + + private val split = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringValue) + val arg = arguments[0] + require(arg is CharValue) + ArrayValue(thisValue.value.split(arg.value).map(::StringValue).toMutableList()) + } + + private val startsWith = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringValue) + val arg = arguments[0] + require(arg is StringValue) + BoolValue(thisValue.value.startsWith(arg.value)) + } + + private val endsWith = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringValue) + val arg = arguments[0] + require(arg is StringValue) + BoolValue(thisValue.value.endsWith(arg.value)) + } + } +} + +object StringBuilderStruct : Value() { + private val create = BuiltinFunction(0) { thisValue, arguments -> + require(thisValue == null) + require(arguments.isEmpty()) + StringBuilderInstance() + } + + init { + this["create"] = create + } + + override fun typeName() = "StringBuilder" + + override fun toString() = "StringBuilderStruct" +} + +class StringBuilderInstance : Value() { + val builder = StringBuilder() + + init { + this["append"] = append + this["append_string"] = appendString + this["append_code_point"] = appendCodePoint + this["to_string"] = toString + this["is_empty"] = isEmpty + this["length"] = length + this["clear"] = clear + } + + override fun typeName() = "StringBuilder" + + override fun toString() = "StringBuilder(\"$builder\")" + + companion object { + private val append = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringBuilderInstance) + val arg = arguments[0] + require(arg is ByteCharValue) + + thisValue.builder.appendCodePoint(arg.value.toInt()) + VoidValue + } + + private val appendString = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringBuilderInstance) + val arg = arguments[0] + require(arg is StringValue) + + thisValue.builder.append(arg.value) + VoidValue + } + + private val appendCodePoint = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is StringBuilderInstance) + val arg = arguments[0] + require(arg is IntegerValue) + + thisValue.builder.appendCodePoint(arg.value.toInt()) + VoidValue + } + + private val toString = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringBuilderInstance) + StringValue(thisValue.builder.toString()) + } + + private val isEmpty = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringBuilderInstance) + BoolValue(thisValue.builder.isEmpty()) + } + + private val length = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringBuilderInstance) + IntegerValue(thisValue.builder.length.toLong()) + } + + private val clear = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is StringBuilderInstance) + thisValue.builder.clear() + VoidValue + } + } +} + +// TODO: no clue if this is works the way it does in Jakt +data class DictionaryIterator(val dictionary: DictionaryValue) : Value() { + private val remainingKeys = dictionary.elements.keys + + init { + this["next"] = next + } + + override fun typeName() = "DictionaryIterator" + + override fun toString() = "DictionaryIterator" + + companion object { + private val next = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is DictionaryIterator) + val nextKey = thisValue.remainingKeys.random() + thisValue.remainingKeys.remove(nextKey) + TupleValue(mutableListOf(nextKey, thisValue.dictionary.elements[nextKey]!!)) + } + } +} + +data class DictionaryValue(val elements: MutableMap) : Value() { + init { + this["is_empty"] = isEmpty + this["get"] = get + this["contains"] = contains + this["set"] = set + this["remove"] = remove + this["clear"] = clear + this["size"] = size + this["keys"] = keys + this["iterator"] = iterator + } + + override fun typeName() = "Dictionary" + + override fun toString() = elements.entries.joinToString(prefix = "{", postfix = "}") { + "${it.key}: ${it.value}" + } + + companion object { + private val isEmpty = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is DictionaryValue) + BoolValue(thisValue.elements.isEmpty()) + } + + private val get = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is DictionaryValue) + require(arguments.size == 1) + OptionalValue(thisValue.elements[arguments[0]]) + } + + private val contains = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is DictionaryValue) + BoolValue(arguments[0] in thisValue.elements) + } + + private val set = BuiltinFunction(2) { thisValue, arguments -> + require(thisValue is DictionaryValue) + thisValue.elements[arguments[0]] = arguments[1] + VoidValue + } + + private val remove = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is DictionaryValue) + BoolValue(thisValue.elements.remove(arguments[0]) != null) + } + + private val clear = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is DictionaryValue) + thisValue.elements.clear() + VoidValue + } + + private val size = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is DictionaryValue) + IntegerValue(thisValue.elements.size.toLong()) + } + + // TODO: What is the ordering guarantee for this in Jakt? + private val keys = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is DictionaryValue) + ArrayValue(thisValue.elements.keys.toMutableList()) + } + + private val iterator = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is DictionaryValue) + DictionaryIterator(thisValue) + } + } +} + +data class SetIterator(val values: MutableSet) : Value() { + init { + this["next"] = next + } + + override fun typeName() = "SetIterator" + + override fun toString() = "SetIterator" + + companion object { + private val next = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is SetIterator) + val nextValue = thisValue.values.random() + thisValue.values.remove(nextValue) + nextValue + } + } +} + +data class SetValue(val values: MutableSet) : Value() { + init { + this["is_empty"] = isEmpty + this["contains"] = contains + this["add"] = add + this["remove"] = remove + this["clear"] = clear + this["size"] = size + this["iterator"] = iterator + } + + override fun typeName() = "Set" + + override fun toString() = values.joinToString(prefix = "{", postfix = "}") + + companion object { + private val isEmpty = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is SetValue) + BoolValue(thisValue.values.isEmpty()) + } + + private val contains = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is SetValue) + BoolValue(arguments[0] in thisValue.values) + } + + private val add = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is SetValue) + thisValue.values.add(arguments[0]) + VoidValue + } + + private val remove = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue is SetValue) + BoolValue(thisValue.values.remove(arguments[0])) + } + + private val clear = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is SetValue) + thisValue.values.clear() + VoidValue + } + + private val size = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is SetValue) + IntegerValue(thisValue.values.size.toLong()) + } + + private val iterator = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is SetValue) + SetIterator(thisValue.values.toMutableSet()) + } + } +} + +data class RangeValue(val start: Long, val end: Long, val isInclusive: Boolean) : Value() { + private var current = start + private val forwards = start <= end + + val range: IntRange + get() = if (isInclusive) start.toInt()..end.toInt() else start.toInt() until end.toInt() + + init { + this["next"] = next + this["inclusive"] = inclusive + this["exclusive"] = exclusive + } + + override fun typeName() = "Range" + + override fun toString() = "Range($start..$end, inclusive = $isInclusive)" + + private fun getAndAdvance(): Long { + return current.also { + if (forwards) current++ else current-- + } + } + + private fun isDone(): Boolean { + return when { + forwards && isInclusive -> current > end + forwards -> current >= end + isInclusive -> current < start + else -> current <= start + } + } + + companion object { + private val next = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is RangeValue) + if (thisValue.isDone()) { + OptionalValue(null) + } else OptionalValue(IntegerValue(thisValue.getAndAdvance())) + } + + private val inclusive = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is RangeValue) + thisValue.copy(isInclusive = true) + } + + private val exclusive = BuiltinFunction(0) { thisValue, _ -> + require(thisValue is RangeValue) + thisValue.copy(isInclusive = false) + } + } +} + +object ErrorStruct : Value() { + private val fromErrno = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue == null) + val arg = arguments[0] + require(arg is IntegerValue) + ErrorInstance(arg.value) + } + + init { + this["from_errno"] = fromErrno + } + + override fun typeName() = "Error" + + override fun toString() = "ErrorStruct" +} + +class ErrorInstance(private val codeValue: Long) : Value() { + override fun typeName() = "Error" + + override fun toString() = "Error(code = $codeValue)" +} + +object FileStruct : Value() { + private val exists = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue == null) + val arg = arguments[0] + require(arg is StringValue) + BoolValue(File(arg.value).exists()) + } + + private val openForReading = BuiltinFunction(1) { thisValue, arguments -> + require(thisValue == null) + val arg = arguments[0] + require(arg is StringValue) + FileInstance(File(arg.value)) + } + + init { + this["exists"] = exists + this["open_for_reading"] = openForReading + } + + override fun typeName() = "File" + + override fun toString() = "FileStruct" +} + +class FileInstance(val file: File) : Value() { + init { + this["read_all"] = readAll + } + + override fun typeName() = "File" + + override fun toString() = "File($file)" + + companion object { + private val readAll = BuiltinFunction(0) { thisValue, arguments -> + require(thisValue is FileInstance) + require(arguments.isEmpty()) + StringValue(thisValue.file.readText()) + } + } +} + +// Free functions + +val jaktGetTargetTripleStringFunction = BuiltinFunction(0) { _, _ -> + JaktProjectListener.targetTriple.get() ?: StringValue("unknown-unknown-unknown-unknown") +} + +val abortFunction = BuiltinFunction(0) { _, _ -> error("aborted") } + +abstract class FormatLikeFunction : FunctionValue(1, Int.MAX_VALUE) { + protected fun getFormatString(arguments: List): String { + require(arguments.isNotEmpty()) + val fmtSpecifier = arguments[0] + require(fmtSpecifier is StringValue) + + val fmtString = FormatStringParser(fmtSpecifier.value).parse() + return fmtString.apply(arguments.drop(1)) + } +} + +object FormatFunction : FormatLikeFunction() { + override fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): ExecutionResult { + return ExecutionResult.Normal(StringValue(getFormatString(arguments))) + } +} + +object PrintFunction : FormatLikeFunction() { + override fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): ExecutionResult { + interpreter.stdout.append(getFormatString(arguments)) + return ExecutionResult.Normal(VoidValue) + } +} + +object PrintlnFunction : FormatLikeFunction() { + override fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): ExecutionResult { + interpreter.stdout.append(getFormatString(arguments)) + interpreter.stdout.append('\n') + return ExecutionResult.Normal(VoidValue) + } +} + +object EprintFunction : FormatLikeFunction() { + override fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): ExecutionResult { + interpreter.stderr.append(getFormatString(arguments)) + return ExecutionResult.Normal(VoidValue) + } +} + +object EprintlnFunction : FormatLikeFunction() { + override fun call(interpreter: Interpreter, thisValue: Value?, arguments: List): ExecutionResult { + interpreter.stderr.append(getFormatString(arguments)) + interpreter.stderr.append('\n') + return ExecutionResult.Normal(VoidValue) + } +} + +// TODO: saturated/truncated functions when we have generic information + diff --git a/src/main/kotlin/org/serenityos/jakt/comptime/extensions.kt b/src/main/kotlin/org/serenityos/jakt/comptime/extensions.kt new file mode 100644 index 00000000..e1343727 --- /dev/null +++ b/src/main/kotlin/org/serenityos/jakt/comptime/extensions.kt @@ -0,0 +1,37 @@ +package org.serenityos.jakt.comptime + +import com.intellij.psi.PsiElement +import com.intellij.psi.util.childrenOfType +import org.serenityos.jakt.psi.JaktPsiElement +import org.serenityos.jakt.psi.ancestors +import org.serenityos.jakt.psi.api.JaktBlock +import org.serenityos.jakt.psi.api.JaktCallExpression +import org.serenityos.jakt.psi.api.JaktIfStatement +import org.serenityos.jakt.psi.api.JaktVariableDeclarationStatement +import org.serenityos.jakt.psi.caching.comptimeCache +import org.serenityos.jakt.psi.findChildOfType + +fun JaktPsiElement.performComptimeEvaluation(): Interpreter.Result { + return comptimeCache().resolveWithCaching(this) { + Interpreter.evaluate(this) + } +} + +fun PsiElement.getComptimeTargetElement(): JaktPsiElement? { + val baseElement = this + + return baseElement.ancestors(withSelf = true).firstOrNull { + when (it) { + is JaktCallExpression, + is JaktVariableDeclarationStatement -> true + else -> false + } + } as? JaktPsiElement +} + +// Utility accessors +val JaktIfStatement.ifStatement: JaktIfStatement? + get() = findChildOfType() + +val JaktIfStatement.elseBlock: JaktBlock? + get() = childrenOfType().getOrNull(1) diff --git a/src/main/kotlin/org/serenityos/jakt/comptime/formatStrings.kt b/src/main/kotlin/org/serenityos/jakt/comptime/formatStrings.kt new file mode 100644 index 00000000..6822281b --- /dev/null +++ b/src/main/kotlin/org/serenityos/jakt/comptime/formatStrings.kt @@ -0,0 +1,378 @@ +package org.serenityos.jakt.comptime + +import com.intellij.openapi.util.Ref +import org.serenityos.jakt.utils.unreachable + +// Based on https://github.com/SerenityOS/serenity/blob/master/AK/Format.cpp + +data class FormatString( + private val literals: List, + private val specifierStrings: List, +) { + init { + require(specifierStrings.size == literals.size - 1) + } + + fun apply(arguments: List): String { + val context = Context(arguments) + val specifiers = specifierStrings.map { FormatSpecifierParser(it).parse(context) } + + return buildString { + append(literals[0]) + + for (i in specifiers.indices) { + specifiers[i].apply(this, arguments) + append(literals[i + 1]) + } + } + } +} + +data class Specifier( + var index: Int = 0, + var alignment: Alignment = Alignment.Right, + var sign: Sign = Sign.OnlyIfNeeded, + var mode: Mode = Mode.Default, + var alternative: Boolean = false, + var fillChar: Char = ' ', + var zeroPad: Boolean = false, + var width: Int? = null, + var precision: Int? = null, +) { + enum class Alignment { + Left, + Center, + Right, + } + + enum class Sign { + OnlyIfNeeded, + Always, + Reserved, + } + + enum class Mode(val str: kotlin.String) { + Default(""), + Binary("b"), + BinaryUppercase("B"), + Decimal("d"), + Octal("o"), + Hexadecimal("x"), + HexadecimalUppercase("X"), + Character("c"), + String("s"), + Pointer("p"), + Float("f"), + HexFloat("a"), + HexFloatUppercase("A"), + HexDump("hex-dump"), + } + + // TODO: This needs a lot of work to be accurate + fun apply(builder: StringBuilder, arguments: List) { + val target = arguments.getOrNull(index) ?: + error("Format specifier refers to argument ${index + 1}, but only ${arguments.size} arguments were provided") + + check(target is PrimitiveValue) { "Cannot format non-primitive type ${target.typeName()} at comptime" } + + check(mode != Mode.Pointer && mode != Mode.HexDump && mode != Mode.Binary && mode != Mode.BinaryUppercase) { + "Unsupported format specifier '${mode.str}'" + } + + check(alignment != Alignment.Center) { + "Unsupported alignment '^'" + } + + check(fillChar == ' ') { + "Unsupported non-space fill character" + } + + val javaFormatSpecifier = buildString { + append('%') + + if (alignment == Alignment.Left) + append('-') + + if (alternative) + append('#') + + when (sign) { + Sign.Always -> append('+') + Sign.Reserved -> append(' ') + Sign.OnlyIfNeeded -> {} + } + + if (zeroPad) + append('0') + + if (width != null) + append(width) + + if (precision != null) { + append('.') + append(precision) + } + + if (mode == Mode.Default) { + mode = when (target) { + is StringValue -> Mode.String + is IntegerValue -> Mode.Decimal + is FloatValue -> Mode.Float + is CharValue -> Mode.Character + else -> unreachable() + } + } + + append(mode.str) + } + + builder.append(javaFormatSpecifier.format(target.value)) + } +} + +open class GenericParser(val text: String) { + var cursor = 0 + + val done: Boolean + get() = cursor > text.lastIndex + + val char: Char + get() = text[cursor] + + fun consumeNumber(): Int? { + val pos = cursor + + while (!done && char.isDigit()) + cursor++ + + if (pos == cursor) + return null + + return text.substring(pos, cursor).toIntOrNull() ?: run { + cursor = pos + null + } + } + + fun consumeIf(char: Char): Boolean { + return if (matches(char)) { + cursor++ + true + } else false + } + + fun consumeIf(str: String): Boolean { + return if (matches(str)) { + cursor += str.length + true + } else false + } + + fun peek(n: Int = 0) = text.getOrNull(cursor + n) + + fun consume() = char.also { cursor++ } + + fun matches(ch: Char): Boolean { + return !done && ch == char + } + + fun matches(string: String): Boolean { + return text.substring(cursor).startsWith(string) + } +} + +class Context(val arguments: List, private var nextValue: Int = 0) { + fun nextIndex() = nextValue++ +} + +class FormatSpecifierParser(specifier: String) : GenericParser(specifier) { + fun parse(context: Context): Specifier = with(Specifier()) { + index = consumeNumber() ?: context.nextIndex() + + if (!consumeIf(':')) + return@with this + + if ("<^>".contains(peek(1) ?: 'a')) { + check(char !in "{}") { "Malformed specifier \"$text\"" } + fillChar = consume() + } + + alignment = when { + consumeIf('<') -> Specifier.Alignment.Left + consumeIf('^') -> Specifier.Alignment.Center + consumeIf('>') -> Specifier.Alignment.Right + else -> alignment + } + + sign = when { + consumeIf('-') -> Specifier.Sign.OnlyIfNeeded + consumeIf('+') -> Specifier.Sign.Always + consumeIf(' ') -> Specifier.Sign.Reserved + else -> sign + } + + if (consumeIf('#')) + alternative = true + + if (consumeIf('0')) + zeroPad = true + + val index = Ref(null) + if (consumeReplacementField(index)) { + if (index.isNull) + index.set(context.nextIndex()) + + val widthValue = context.arguments.getOrNull(index.get()!!) + check(widthValue != null) { + "Width parameter refers to non-existent argument ${index.get()!!}" + } + check(widthValue is IntegerValue) { + "Expected integer for width argument at index ${index.get()!!}, found ${widthValue.typeName()}" + } + width = widthValue.value.toInt() + } else { + val num = consumeNumber() + if (num != null) + width = num + } + + if (consumeIf('.')) { + if (consumeReplacementField(index)) { + if (index.isNull) + index.set(context.nextIndex()) + + val precisionValue = context.arguments.getOrNull(index.get()!!) + check(precisionValue != null) { + "Precision parameter refers to non-existent argument ${index.get()!!}" + } + check(precisionValue is IntegerValue) { + "Expected integer for precision argument at index ${index.get()!!}, found ${precisionValue.typeName()}" + } + precision = precisionValue.value.toInt() + } else { + val num = consumeNumber() + if (num != null) + precision = num + } + } + + mode = when { + consumeIf('b') -> Specifier.Mode.Binary + consumeIf('B') -> Specifier.Mode.BinaryUppercase + consumeIf('d') -> Specifier.Mode.Decimal + consumeIf('o') -> Specifier.Mode.Octal + consumeIf('x') -> Specifier.Mode.Hexadecimal + consumeIf('X') -> Specifier.Mode.HexadecimalUppercase + consumeIf('c') -> Specifier.Mode.Character + consumeIf('s') -> Specifier.Mode.String + consumeIf('P') -> Specifier.Mode.Pointer + consumeIf('f') -> Specifier.Mode.Float + consumeIf('a') -> Specifier.Mode.HexFloat + consumeIf('A') -> Specifier.Mode.HexFloatUppercase + consumeIf("hex-dump") -> Specifier.Mode.HexDump + matches('}') -> Specifier.Mode.Default + !done -> error("Unknown format specifier '$char'") + else -> mode + } + + check(consumeIf('}')) + + check(done) + + this + } + + private fun consumeReplacementField(ref: Ref): Boolean { + if (!consumeIf('{')) + return false + + ref.set(consumeNumber()) + + check(consumeIf('}')) + + return true + } +} + +class FormatStringParser(formatString: String) : GenericParser(formatString) { + fun parse(): FormatString { + if (done) + return FormatString(listOf(""), emptyList()) + + val literals = mutableListOf() + val specifiers = mutableListOf() + + while (true) { + literals.add(consumeLiteral()) + val specifier = consumeSpecifier() + + if (specifier == null) { + check(done) { + "Expected specifier at offset $cursor" + } + + return FormatString(literals, specifiers) + } + + specifiers.add(specifier) + } + } + + private fun consumeLiteral(): String { + val pos = cursor + + while (!done) { + if (consumeIf("{{")) + continue + + if (consumeIf("}}")) + continue + + if (matches("{") || matches("}")) + return text.substring(pos, cursor) + + cursor++ + } + + return text.substring(pos) + } + + private fun consumeSpecifier(): String? { + require(!matches("}")) { + "Unexpected '}' at offset $cursor" + } + + if (!consumeIf("{")) + return null + + val pos = cursor + + consumeNumber() + + if (consumeIf(":")) { + var level = 1 + + while (level > 0) { + check(!done) { + "Unexpected end of string in format specifier" + } + + if (matches("{")) + level++ + + if (matches("}")) + level-- + + cursor++ + } + + return text.substring(pos, cursor) + } + + check(consumeIf("}")) { + "Expected '}' at offset $cursor" + } + + return "" + } +} diff --git a/src/main/kotlin/org/serenityos/jakt/folding/JaktBlockFoldingBuilder.kt b/src/main/kotlin/org/serenityos/jakt/folding/JaktBlockFoldingBuilder.kt index 4b403b7c..1af3b659 100644 --- a/src/main/kotlin/org/serenityos/jakt/folding/JaktBlockFoldingBuilder.kt +++ b/src/main/kotlin/org/serenityos/jakt/folding/JaktBlockFoldingBuilder.kt @@ -66,6 +66,11 @@ class JaktBlockFoldingBuilder : CustomFoldingBuilder() { descriptors += FoldingDescriptor(o, o.textRange) } + override fun visitMatchExpression(o: JaktMatchExpression) { + val body = o.matchBody ?: return + descriptors += FoldingDescriptor(o, TextRange(body.curlyOpen.startOffset, body.curlyClose.endOffset)) + } + private fun FoldingDescriptor(element: PsiElement) = FoldingDescriptor(element, element.textRange) } } diff --git a/src/main/kotlin/org/serenityos/jakt/intentions/ImportNSDeclarationIntention.kt b/src/main/kotlin/org/serenityos/jakt/intentions/ImportNSDeclarationIntention.kt index 7709dabf..bd68aba9 100644 --- a/src/main/kotlin/org/serenityos/jakt/intentions/ImportNSDeclarationIntention.kt +++ b/src/main/kotlin/org/serenityos/jakt/intentions/ImportNSDeclarationIntention.kt @@ -8,8 +8,8 @@ import org.serenityos.jakt.psi.JaktPsiFactory import org.serenityos.jakt.psi.ancestorOfType import org.serenityos.jakt.psi.api.JaktImport import org.serenityos.jakt.psi.api.JaktPlainQualifier -import org.serenityos.jakt.psi.declaration.aliasIdent -import org.serenityos.jakt.psi.declaration.nameIdent +import org.serenityos.jakt.psi.declaration.aliasString +import org.serenityos.jakt.psi.declaration.targetString import org.serenityos.jakt.psi.reference.index class ImportNSDeclarationIntention : JaktIntention("Add import for member") { @@ -27,7 +27,7 @@ class ImportNSDeclarationIntention : JaktIntention() - .find { it.nameIdent.text == resolvedFile.name.substringBefore(".jakt") } + .find { it.targetString == resolvedFile.name.substringBefore(".jakt") } ?: return null description = "Add import for \"${qualifier.text}\"" @@ -51,9 +51,9 @@ class ImportNSDeclarationIntention : JaktIntention if (is64Bit) "i686-pc-windows-msvc" else "x86_64-pc-windows-msvc" + "linux" in name -> "x86_64-pc-linux-gnu" + "bsd" in name -> "x86_64-pc-bsd-unknown" + "mac" in name || "darwin" in name -> "x86_64-apple-darwin-unknown" + "unix" in name -> "x86_64-pc-unix-unknown" + else -> "unknown-unknown-unknown-unknown" + } + + targetTriple.set(StringValue(tripleGuess)) + } + + companion object { + // The system triple (ex: "x64_64-pc-linux-gnu"). This is used in comptime + // execution to populate the ___jakt_get_target_triple_string function + val targetTriple = AtomicReference() } } diff --git a/src/main/kotlin/org/serenityos/jakt/psi/caching/JaktCache.kt b/src/main/kotlin/org/serenityos/jakt/psi/caching/JaktCache.kt index 26012763..0f737e6e 100644 --- a/src/main/kotlin/org/serenityos/jakt/psi/caching/JaktCache.kt +++ b/src/main/kotlin/org/serenityos/jakt/psi/caching/JaktCache.kt @@ -39,6 +39,17 @@ abstract class JaktCache(project: Project) : Disposable { return map.getOrPut(key) { resolver(key) } as V } + @Suppress("UNCHECKED_CAST") + fun resolve(key: K): V? { + ProgressManager.checkCanceled() + return getCacheFor(key)[key] as V? + } + + fun set(key: K, value: V) { + ProgressManager.checkCanceled() + getCacheFor(key)[key] = value + } + private fun getCacheFor(element: PsiElement): ConcurrentMap { val owner = element.modificationBoundary @@ -55,6 +66,8 @@ abstract class JaktCache(project: Project) : Disposable { class JaktResolveCache(project: Project) : JaktCache(project) class JaktTypeCache(project: Project) : JaktCache(project) +class JaktComptimeCache(project: Project) : JaktCache(project) fun PsiElement.resolveCache() = project.service() fun PsiElement.typeCache() = project.service() +fun PsiElement.comptimeCache() = project.service() diff --git a/src/main/kotlin/org/serenityos/jakt/psi/declaration/JaktImportMixin.kt b/src/main/kotlin/org/serenityos/jakt/psi/declaration/JaktImportMixin.kt index b29a2673..c8909298 100644 --- a/src/main/kotlin/org/serenityos/jakt/psi/declaration/JaktImportMixin.kt +++ b/src/main/kotlin/org/serenityos/jakt/psi/declaration/JaktImportMixin.kt @@ -1,12 +1,15 @@ package org.serenityos.jakt.psi.declaration import com.intellij.lang.ASTNode -import com.intellij.psi.PsiElement import com.intellij.psi.stubs.IStubElementType import org.serenityos.jakt.JaktFile import org.serenityos.jakt.JaktTypes +import org.serenityos.jakt.comptime.ArrayValue +import org.serenityos.jakt.comptime.StringValue +import org.serenityos.jakt.comptime.performComptimeEvaluation import org.serenityos.jakt.project.jaktProject import org.serenityos.jakt.psi.api.JaktImport +import org.serenityos.jakt.psi.caching.typeCache import org.serenityos.jakt.psi.findChildrenOfType import org.serenityos.jakt.psi.named.JaktStubbedNamedElement import org.serenityos.jakt.psi.reference.singleRef @@ -24,11 +27,29 @@ abstract class JaktImportMixin : JaktStubbedNamedElement, JaktIm override fun getReference() = singleRef { resolveFile() } } -val JaktImport.nameIdent: PsiElement - get() = originalElement.findChildrenOfType(JaktTypes.IDENTIFIER).first() +// Note that JaktImport doesn't have a type, and thus doesn't use the typeCache normally. So we +// can repurpose it to save this string +val JaktImport.targetString: String + get() = typeCache().resolveWithCaching(this) { + importTarget.identifier?.let { return@resolveWithCaching it.text } -val JaktImport.aliasIdent: PsiElement? - get() = originalElement.findChildrenOfType(JaktTypes.IDENTIFIER).getOrNull(1) + when (val comptimeValue = importTarget.callExpression?.performComptimeEvaluation()?.value) { + is StringValue -> comptimeValue.value + is ArrayValue -> { + val project = jaktProject + val containingFile = containingFile.originalFile.virtualFile + for (value in comptimeValue.values) { + if (value is StringValue && project.resolveImportedFile(containingFile, value.value) != null) + return@resolveWithCaching value.value + } + "__UNKNOWN" + } + else -> "__UNKNOWN" + } + } + +val JaktImport.aliasString: String? + get() = originalElement.findChildrenOfType(JaktTypes.IDENTIFIER).firstOrNull()?.text fun JaktImport.resolveFile(): JaktFile? = - jaktProject.resolveImportedFile(containingFile.originalFile.virtualFile, nameIdent.text) + jaktProject.resolveImportedFile(containingFile.originalFile.virtualFile, targetString) diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index f6d0f259..24883eaa 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -170,6 +170,7 @@ + + + + + + diff --git a/src/main/resources/grammar/Jakt.bnf b/src/main/resources/grammar/Jakt.bnf index fc9679c9..567e875d 100644 --- a/src/main/resources/grammar/Jakt.bnf +++ b/src/main/resources/grammar/Jakt.bnf @@ -102,12 +102,13 @@ upper NamespaceDeclaration ::= NAMESPACE_KEYWORD IDENTIFIER NamespaceBody { NamespaceBody ::= CURLY_OPEN NL TopLevelDefinitionList? NL CURLY_CLOSE // Import Statement -upper Import ::= IMPORT_KEYWORD !EXTERN_KEYWORD IDENTIFIER ImportAs? ImportBraceList? { +upper Import ::= IMPORT_KEYWORD !EXTERN_KEYWORD ImportTarget ImportAs? ImportBraceList? { implements="org.serenityos.jakt.psi.declaration.JaktDeclaration" mixin="org.serenityos.jakt.psi.declaration.JaktImportMixin" stubClass="org.serenityos.jakt.stubs.JaktImportStub" elementTypeFactory="org.serenityos.jakt.stubs.JaktStubFactoryKt.jaktStubFactory" } +ImportTarget ::= IDENTIFIER !PAREN_OPEN | CallExpression private ImportAs ::= KEYWORD_AS IDENTIFIER ImportBraceList ::= CURLY_OPEN NL <>? NL CURLY_CLOSE ImportBraceEntry ::= IDENTIFIER { diff --git a/src/test/kotlin/org/serenityos/jakt/JaktBaseTest.kt b/src/test/kotlin/org/serenityos/jakt/JaktBaseTest.kt index 8abbcedd..4ddc75ff 100644 --- a/src/test/kotlin/org/serenityos/jakt/JaktBaseTest.kt +++ b/src/test/kotlin/org/serenityos/jakt/JaktBaseTest.kt @@ -27,9 +27,8 @@ abstract class JaktBaseTest : BasePlatformTestCase() { tagElements.forEach { comment -> val matches = tagRegex.findAll(comment.text).toList() - check(matches.isNotEmpty()) { - "Comment in test with no tags" - } + if (matches.isEmpty()) + return@forEach for (match in matches) { val group = match.groups[1] ?: error("Invalid tag comment") @@ -52,6 +51,10 @@ abstract class JaktBaseTest : BasePlatformTestCase() { } } + check(taggedElements.isNotEmpty()) { + "No tagged elements found" + } + return taggedElements } diff --git a/src/test/kotlin/org/serenityos/jakt/comptime/JaktBasicComptimeTest.kt b/src/test/kotlin/org/serenityos/jakt/comptime/JaktBasicComptimeTest.kt new file mode 100644 index 00000000..cdff0a18 --- /dev/null +++ b/src/test/kotlin/org/serenityos/jakt/comptime/JaktBasicComptimeTest.kt @@ -0,0 +1,18 @@ +package org.serenityos.jakt.comptime + +class JaktBasicComptimeTest : JaktComptimeTest() { + fun `test bitwise operators`() = doStdoutTest(""" + comptime bitwise() { + if (((0x123 ^ 0x456) << 12) | (0x789 & 0xabc)) == 0x575288 { + print("PASS") + } else { + print("FAIL") + } + } + + function main() { + bitwise() + //^T + } + """.trimIndent(), "PASS") +} diff --git a/src/test/kotlin/org/serenityos/jakt/comptime/JaktComptimeTest.kt b/src/test/kotlin/org/serenityos/jakt/comptime/JaktComptimeTest.kt new file mode 100644 index 00000000..567cd6ad --- /dev/null +++ b/src/test/kotlin/org/serenityos/jakt/comptime/JaktComptimeTest.kt @@ -0,0 +1,41 @@ +package org.serenityos.jakt.comptime + +import org.intellij.lang.annotations.Language +import org.serenityos.jakt.JaktBaseTest + +abstract class JaktComptimeTest : JaktBaseTest() { + protected fun doTest(@Language("Jakt") text: String, test: (Interpreter.Result) -> Unit) { + setupFor(text) + + val taggedElements = extractTaggedElements() + val elements = taggedElements["T"] + check(!elements.isNullOrEmpty()) { + "No tagged elements found" + } + + check(elements.size == 1) { + "More than one tagged element found" + } + + val targetElement = elements.single().getComptimeTargetElement() + check(targetElement != null) { + "Element cannot be evaluated at comptime" + } + + test(Interpreter.evaluate(targetElement)) + } + + protected fun doStdoutTest(@Language("Jakt") text: String, expectedStdout: String) = doTest(text) { + check(expectedStdout == it.stdout) + } + + protected fun doStderrTest(@Language("Jakt") text: String, expectedStderr: String) = doTest(text) { + check(expectedStderr == it.stderr) + } + + protected fun doValueTest(@Language("Jakt") text: String, expectedValue: Value) = doTest(text) { + check(it.value == expectedValue) { + "Expected $expectedValue, found ${it.value}" + } + } +} diff --git a/src/test/kotlin/org/serenityos/jakt/comptime/JaktStringComptimeTest.kt b/src/test/kotlin/org/serenityos/jakt/comptime/JaktStringComptimeTest.kt new file mode 100644 index 00000000..d8e9bdd7 --- /dev/null +++ b/src/test/kotlin/org/serenityos/jakt/comptime/JaktStringComptimeTest.kt @@ -0,0 +1,53 @@ +package org.serenityos.jakt.comptime + +class JaktStringComptimeTest : JaktComptimeTest() { + fun `test string methods`() = doStdoutTest(""" + comptime empty() throws => "".is_empty() + comptime length() throws => "a string of length 21".length() + comptime substring() throws => "abcdef".substring(start: 1, length: 3) + comptime hash() throws => "well, hello friends".hash() + comptime number() throws => String::number(123) + comptime to_uint() throws => "123".to_uint() + comptime to_int() throws => "-456".to_int() + comptime is_whitespace() throws => " ".is_whitespace() and not "abc".is_whitespace() + comptime contains() throws => "abcdef".contains("bcd") + comptime replace() throws => "well, hiya friends".replace(replace: "hiya", with: "hello") + comptime byte_at() throws => "AAAA".byte_at(3) + comptime starts_with() throws => "abcdef".starts_with("abc") + comptime ends_with() throws => "abcdef".ends_with("def") + comptime repeated() throws => String::repeated(character: 'A', count: 5) + comptime split() throws => "a;b;c".split(';') + + comptime test() throws { + mut success = empty() + success = success and length() == 21 + success = success and substring() == "bcd" + success = success and hash() == "well, hello friends".hash() + success = success and number() == "123" + success = success and to_uint()! == 123u32 + success = success and to_int()! == -456i32 + success = success and is_whitespace() + success = success and contains() + success = success and replace() == "well, hello friends" + success = success and byte_at() == 0x41 + success = success and starts_with() + success = success and ends_with() + success = success and repeated() == "AAAAA" + let parts = split() + success = success and parts[0] == "a" + success = success and parts[1] == "b" + success = success and parts[2] == "c" + + if success { + print("PASS") + } else { + print("FAIL") + } + } + + function main() { + test() + //^T + } + """, "PASS") +}