From 75855514e0423b6ad9323d83e9e651ece163e320 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Tue, 2 Jan 2024 20:32:14 +0800 Subject: [PATCH] feat(rust): add handle for function calling in type --- .../chapi/ast/rustast/RustAstBaseListener.kt | 30 +++++++++++-- .../ast/rustast/RustFullIdentListener.kt | 12 +++-- .../ast/rustast/RustFullIdentListenerTest.kt | 44 +++++++++++++++++++ 3 files changed, 80 insertions(+), 6 deletions(-) diff --git a/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt b/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt index 5051d06f..cf4c84e1 100644 --- a/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt +++ b/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt @@ -43,13 +43,22 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis */ var localVars: MutableMap = mutableMapOf() + /** + * localFunctions will store all local functions in the current scope + * + * for individual function: currentFunction.Name + * for Implementation: currentNode.NodeName + "::" + currentFunction.Name + * + */ + var localFunctions: MutableMap = mutableMapOf() + /** * packageName will parse from fileName, like: * - "src/main.rs" -> "main" * - "enfer_core/src/lib.rs" -> "enfer_core" * - "enfer_core/src/document.rs" -> "enfer_core::document" */ - private val packageName: String + val packageName: String get() { val modulePath = fileName.substringBeforeLast("src") .substringBeforeLast(File.separator) @@ -195,7 +204,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis open fun lookupByType(typeText: String?): String { if (typeText == null) return "" - val text = if (typeText.contains("::")) { + var text = if (typeText.contains("::")) { val crateName = typeText.split("::").first() if (crateName == "std" || crateName == "crate") { typeText @@ -206,6 +215,19 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis typeText } + if (localVars.containsKey(text)) { + return localVars[text] ?: "" + } + + if (localFunctions.containsKey(text)) { + val function = localFunctions[text]!! + text = if (function.MultipleReturns.isNotEmpty()) { + function.MultipleReturns.first().TypeType + } else { + function.ReturnType + } + } + imports.filter { it.AsName == text }.forEach { return it.Source } @@ -388,7 +410,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis generics + typeList.map { CodeProperty( TypeType = lookupType(it), - TypeValue = lookupType(it) + TypeValue = it.text, ) } }.flatten() @@ -410,8 +432,10 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis if (isEnteredImplementation == false) { isEnteredIndividualFunction = false individualFunctions.add(currentIndividualFunction) + localFunctions[currentIndividualFunction.Name] = currentIndividualFunction } else { currentNode.Functions += currentFunction + localFunctions[currentNode.NodeName + "::" + currentFunction.Name] = currentFunction } localVars.clear() diff --git a/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustFullIdentListener.kt b/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustFullIdentListener.kt index 4a55730b..762bf209 100644 --- a/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustFullIdentListener.kt +++ b/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustFullIdentListener.kt @@ -29,9 +29,15 @@ class RustFullIdentListener(fileName: String) : RustAstBaseListener(fileName) { override fun enterCallExpression(ctx: RustParser.CallExpressionContext?) { val functionName = ctx?.expression()?.text val split = functionName?.split("::") - val nodeName = split?.dropLast(1)?.joinToString("::") ?: "" + val lastType = split?.dropLast(1)?.joinToString("::") + val nodeName = if (lastType.isNullOrEmpty()) { + split?.firstOrNull() ?: "" + } else { + lastType + } + functionInstance.FunctionCalls += CodeCall( - Package = split?.dropLast(1)?.joinToString("::") ?: "", + Package = lastType ?: packageName, NodeName = lookupByType(nodeName), FunctionName = split?.last() ?: "", OriginNodeName = nodeName, @@ -52,7 +58,7 @@ class RustFullIdentListener(fileName: String) : RustAstBaseListener(fileName) { // todo: handle method call functionInstance.FunctionCalls += CodeCall( - Package = nodeName, + Package = packageName, NodeName = lookupByType(nodeName), OriginNodeName = instanceVar.ifEmpty { nodeName }, FunctionName = functionName, diff --git a/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustFullIdentListenerTest.kt b/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustFullIdentListenerTest.kt index 57441a84..37d8c094 100644 --- a/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustFullIdentListenerTest.kt +++ b/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustFullIdentListenerTest.kt @@ -485,4 +485,48 @@ fn main() { assertEquals("std::sync::Arc", firstFunction.MultipleReturns[1].TypeType) assertEquals("embedding::semantic::SemanticError", firstFunction.MultipleReturns[2].TypeType) } + + @Test + fun should_handle_for_node_type_in_function_call() { + val code = """ + use std::sync::Arc; + + pub use embedding::Semantic; + pub use embedding::semantic::SemanticError; + + pub fn init_semantic(model: Vec, tokenizer_data: Vec) -> Result, SemanticError> { + let result = Semantic::init_semantic(model, tokenizer_data)?; + Ok(Arc::new(result)) + } + + pub fn embed() -> Embedding { + let model = std::fs::read("../model/model.onnx").unwrap(); + let tokenizer_data = std::fs::read("../model/tokenizer.json").unwrap(); + + let semantic = init_semantic(model, tokenizer_data).unwrap(); + semantic.embed("hello world").unwrap() + } + """.trimIndent() + + val codeContainer = RustAnalyser().analysis(code, "lib.rs") + val codeDataStruct = codeContainer.DataStructures + val embedFunc = codeDataStruct[0].Functions[1] + + val functionCalls = embedFunc.FunctionCalls + val outputs = functionCalls.map { + "${it.NodeName} -> ${it.FunctionName} -> ${it.OriginNodeName}" + }.joinToString("\n") + + assertEquals(8, functionCalls.size) + assertEquals(outputs, """ + std::fs::read -> unwrap -> std::fs::read + std::fs -> read -> std::fs + std::fs::read -> unwrap -> std::fs::read + std::fs -> read -> std::fs + embedding::Semantic -> unwrap -> init_semantic + embedding::Semantic -> init_semantic -> init_semantic + semantic.embed -> unwrap -> semantic.embed + semantic -> embed -> semantic + """.trimIndent()) + } }