Skip to content

Commit

Permalink
feat(rust): add handle for function calling in type
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Jan 2, 2024
1 parent 9708673 commit 7585551
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,22 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
*/
var localVars: MutableMap<String, String> = mutableMapOf()

/**
* localFunctions will store all local functions in the current scope
*
* for individual function: currentFunction.Name
* for Implementation: currentNode.NodeName + "::" + currentFunction.Name
*
*/
var localFunctions: MutableMap<String, CodeFunction> = mutableMapOf()

/**
* packageName will parse from fileName, like:
* - "src/main.rs" -> "main"
* - "enfer_core/src/lib.rs" -> "enfer_core"
* - "enfer_core/src/document.rs" -> "enfer_core::document"
*/
private val packageName: String
val packageName: String
get() {
val modulePath = fileName.substringBeforeLast("src")
.substringBeforeLast(File.separator)
Expand Down Expand Up @@ -195,7 +204,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis

open fun lookupByType(typeText: String?): String {
if (typeText == null) return ""
val text = if (typeText.contains("::")) {
var text = if (typeText.contains("::")) {
val crateName = typeText.split("::").first()
if (crateName == "std" || crateName == "crate") {
typeText
Expand All @@ -206,6 +215,19 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
typeText
}

if (localVars.containsKey(text)) {
return localVars[text] ?: ""
}

if (localFunctions.containsKey(text)) {
val function = localFunctions[text]!!
text = if (function.MultipleReturns.isNotEmpty()) {
function.MultipleReturns.first().TypeType
} else {
function.ReturnType
}
}

imports.filter { it.AsName == text }.forEach {
return it.Source
}
Expand Down Expand Up @@ -388,7 +410,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
generics + typeList.map {
CodeProperty(
TypeType = lookupType(it),
TypeValue = lookupType(it)
TypeValue = it.text,
)
}
}.flatten()
Expand All @@ -410,8 +432,10 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis
if (isEnteredImplementation == false) {
isEnteredIndividualFunction = false
individualFunctions.add(currentIndividualFunction)
localFunctions[currentIndividualFunction.Name] = currentIndividualFunction
} else {
currentNode.Functions += currentFunction
localFunctions[currentNode.NodeName + "::" + currentFunction.Name] = currentFunction
}

localVars.clear()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@ class RustFullIdentListener(fileName: String) : RustAstBaseListener(fileName) {
override fun enterCallExpression(ctx: RustParser.CallExpressionContext?) {
val functionName = ctx?.expression()?.text
val split = functionName?.split("::")
val nodeName = split?.dropLast(1)?.joinToString("::") ?: ""
val lastType = split?.dropLast(1)?.joinToString("::")
val nodeName = if (lastType.isNullOrEmpty()) {
split?.firstOrNull() ?: ""
} else {
lastType
}

functionInstance.FunctionCalls += CodeCall(
Package = split?.dropLast(1)?.joinToString("::") ?: "",
Package = lastType ?: packageName,
NodeName = lookupByType(nodeName),
FunctionName = split?.last() ?: "",
OriginNodeName = nodeName,
Expand All @@ -52,7 +58,7 @@ class RustFullIdentListener(fileName: String) : RustAstBaseListener(fileName) {

// todo: handle method call
functionInstance.FunctionCalls += CodeCall(
Package = nodeName,
Package = packageName,
NodeName = lookupByType(nodeName),
OriginNodeName = instanceVar.ifEmpty { nodeName },
FunctionName = functionName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,4 +485,48 @@ fn main() {
assertEquals("std::sync::Arc", firstFunction.MultipleReturns[1].TypeType)
assertEquals("embedding::semantic::SemanticError", firstFunction.MultipleReturns[2].TypeType)
}

@Test
fun should_handle_for_node_type_in_function_call() {
val code = """
use std::sync::Arc;
pub use embedding::Semantic;
pub use embedding::semantic::SemanticError;
pub fn init_semantic(model: Vec<u8>, tokenizer_data: Vec<u8>) -> Result<Arc<Semantic>, SemanticError> {
let result = Semantic::init_semantic(model, tokenizer_data)?;
Ok(Arc::new(result))
}
pub fn embed() -> Embedding {
let model = std::fs::read("../model/model.onnx").unwrap();
let tokenizer_data = std::fs::read("../model/tokenizer.json").unwrap();
let semantic = init_semantic(model, tokenizer_data).unwrap();
semantic.embed("hello world").unwrap()
}
""".trimIndent()

val codeContainer = RustAnalyser().analysis(code, "lib.rs")
val codeDataStruct = codeContainer.DataStructures
val embedFunc = codeDataStruct[0].Functions[1]

val functionCalls = embedFunc.FunctionCalls
val outputs = functionCalls.map {
"${it.NodeName} -> ${it.FunctionName} -> ${it.OriginNodeName}"
}.joinToString("\n")

assertEquals(8, functionCalls.size)
assertEquals(outputs, """
std::fs::read -> unwrap -> std::fs::read
std::fs -> read -> std::fs
std::fs::read -> unwrap -> std::fs::read
std::fs -> read -> std::fs
embedding::Semantic -> unwrap -> init_semantic
embedding::Semantic -> init_semantic -> init_semantic
semantic.embed -> unwrap -> semantic.embed
semantic -> embed -> semantic
""".trimIndent())
}
}

0 comments on commit 7585551

Please # to comment.