@@ -2,6 +2,7 @@ package cc.unitmesh.cf.domains.spec
2
2
3
3
import cc.unitmesh.nlp.embedding.Embedding
4
4
import cc.unitmesh.nlp.embedding.EmbeddingProvider
5
+ import cc.unitmesh.nlp.embedding.EncodingTokenizer
5
6
import cc.unitmesh.rag.document.Document
6
7
import cc.unitmesh.rag.retriever.EmbeddingStoreRetriever
7
8
import cc.unitmesh.rag.splitter.MarkdownHeaderTextSplitter
@@ -15,40 +16,53 @@ class SpecRelevantSearch(val embeddingProvider: EmbeddingProvider) {
15
16
// Retriever over the embedded specification chunks; assigned in the init block.
private lateinit var vectorStoreRetriever: EmbeddingStoreRetriever

// Results already computed per query string, kept so repeated queries skip
// embedding and retrieval.
private val searchCache: MutableMap<String, List<SearchResult>> = mutableMapOf()
19
20
20
21
init {
    // Load the bundled specification. checkNotNull reports a missing resource with a
    // descriptive message, and `use` closes the reader once the text has been read.
    val text = checkNotNull(javaClass.getResourceAsStream("/be/specification.md")) {
        "Classpath resource /be/specification.md not found"
    }.bufferedReader().use { it.readText() }

    // Split on the top two heading levels; the heading text is read back from each
    // section's metadata under the "H1" / "H2" keys below.
    val headersToSplitOn: List<Pair<String, String>> = listOf(
        "#" to "H1",
        "##" to "H2",
    )

    val sections = MarkdownHeaderTextSplitter(headersToSplitOn).splitText(text)

    // Prefix each section with its heading path so the embedded text carries that context.
    // Headings absent from the metadata are skipped instead of being rendered as "null".
    val sectionsWithHeader = sections.map { section ->
        val header = listOfNotNull(section.metadata["H1"], section.metadata["H2"]).joinToString(" > ")
        if (header.isEmpty()) section else section.copy(text = "$header ${section.text}")
    }

    // Keep every chunk the splitter produces: taking only the first chunk per section
    // would drop the tail of any section longer than one chunk from the index.
    val documentList = TokenTextSplitter(chunkSize = 384).apply(sectionsWithHeader)

    val vectorStore: EmbeddingStore<Document> = InMemoryEmbeddingStore()
    val embeddings: List<Embedding> = documentList.map { document ->
        embeddingProvider.embed(document.text)
    }
    vectorStore.addAll(embeddings, documentList)

    // NOTE(review): 5 and 0.6 are presumably the maximum number of results and the minimum
    // similarity score — confirm against EmbeddingStoreRetriever's constructor.
    this.vectorStoreRetriever = EmbeddingStoreRetriever(vectorStore, 5, 0.6)
}
41
45
42
46
// TODO: change to search engine
43
/**
 * Returns the specification chunks retrieved for [query].
 *
 * The query is embedded with [embeddingProvider] and passed to the retriever built in
 * the init block. Each match is exposed as a [SearchResult] whose `source` is the string
 * form of the matched document's metadata and whose `content` is the document text.
 * Results are cached per query string, so a repeated query returns the cached list
 * without embedding or retrieving again.
 *
 * @param query the text to search the specification for
 * @return the results for [query]
 */
fun search(query: String): List<SearchResult> =
    // getOrPut replaces the containsKey / `!!` pair with a single null-safe lookup.
    searchCache.getOrPut(query) {
        val queryEmbedding = embeddingProvider.embed(query)
        vectorStoreRetriever.retrieve(queryEmbedding).map { match ->
            SearchResult(
                source = match.embedded.metadata.toString(),
                content = match.embedded.text,
            )
        }
    }
54
- }
63
+ }
64
+
65
/**
 * One match returned by a specification search.
 *
 * @property source string form of the matched document's metadata
 * @property content text of the matched document
 */
data class SearchResult(
    val source: String,
    val content: String,
)
0 commit comments