Skip to content

Commit 208775d

Browse files
committed
feat: make spec workflow works
1 parent 2173bd1 commit 208775d

File tree

6 files changed

+51
-22
lines changed

6 files changed

+51
-22
lines changed

cocoa-core/src/main/kotlin/cc/unitmesh/rag/splitter/EncodingTokenizer.kt cocoa-core/src/main/kotlin/cc/unitmesh/nlp/embedding/EncodingTokenizer.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package cc.unitmesh.rag.splitter
1+
package cc.unitmesh.nlp.embedding
22

33
interface EncodingTokenizer {
44
fun encode(text: String): List<Int>

cocoa-core/src/main/kotlin/cc/unitmesh/rag/splitter/OpenAiEncoding.kt cocoa-core/src/main/kotlin/cc/unitmesh/nlp/embedding/OpenAiEncoding.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package cc.unitmesh.rag.splitter
1+
package cc.unitmesh.nlp.embedding
22

33
import com.knuddels.jtokkit.Encodings
44
import com.knuddels.jtokkit.api.Encoding

cocoa-core/src/main/kotlin/cc/unitmesh/rag/splitter/TokenTextSplitter.kt

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
*/
1818
package cc.unitmesh.rag.splitter
1919

20+
import cc.unitmesh.nlp.embedding.EncodingTokenizer
21+
import cc.unitmesh.nlp.embedding.OpenAiEncoding
2022
import kotlin.math.max
2123
import kotlin.math.min
2224

src/main/kotlin/cc/unitmesh/cf/domains/spec/SpecRelevantSearch.kt

+23-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cc.unitmesh.cf.domains.spec
22

33
import cc.unitmesh.nlp.embedding.Embedding
44
import cc.unitmesh.nlp.embedding.EmbeddingProvider
5+
import cc.unitmesh.nlp.embedding.EncodingTokenizer
56
import cc.unitmesh.rag.document.Document
67
import cc.unitmesh.rag.retriever.EmbeddingStoreRetriever
78
import cc.unitmesh.rag.splitter.MarkdownHeaderTextSplitter
@@ -15,40 +16,53 @@ class SpecRelevantSearch(val embeddingProvider: EmbeddingProvider) {
1516
private lateinit var vectorStoreRetriever: EmbeddingStoreRetriever
1617

1718
// cached for performance
18-
private val searchCache: MutableMap<String, List<String>> = mutableMapOf()
19+
private val searchCache: MutableMap<String, List<SearchResult>> = mutableMapOf()
1920

2021
init {
2122
val text = javaClass.getResourceAsStream("/be/specification.md")!!.bufferedReader().readText()
2223
val headersToSplitOn: List<Pair<String, String>> = listOf(
23-
Pair("#", "Header 1"),
24-
Pair("##", "Header 2"),
25-
Pair("###", "Header 3"),
24+
Pair("#", "H1"),
25+
Pair("##", "H2"),
2626
)
2727

2828
val documents = MarkdownHeaderTextSplitter(headersToSplitOn)
2929
.splitText(text)
3030

31-
val documentList = TokenTextSplitter(chunkSize = 384).apply(documents)
31+
val documentList = documents.map {
32+
val header = "${it.metadata["H1"]} > ${it.metadata["H2"]}"
33+
val withHeader = it.copy(text = "$header ${it.text}")
34+
TokenTextSplitter(chunkSize = 384).apply(listOf(withHeader)).first()
35+
}
3236

3337
val vectorStore: EmbeddingStore<Document> = InMemoryEmbeddingStore()
3438
val embeddings: List<Embedding> = documentList.map {
3539
embeddingProvider.embed(it.text)
3640
}
3741
vectorStore.addAll(embeddings, documentList)
3842

39-
this.vectorStoreRetriever = EmbeddingStoreRetriever(vectorStore)
43+
this.vectorStoreRetriever = EmbeddingStoreRetriever(vectorStore, 5, 0.6)
4044
}
4145

4246
// TODO: change to search engine
43-
fun search(query: String): List<String> {
47+
fun search(query: String): List<SearchResult> {
4448
if (searchCache.containsKey(query)) {
4549
return searchCache[query]!!
4650
}
4751

4852
val queryEmbedding = embeddingProvider.embed(query)
4953
val similarDocuments = vectorStoreRetriever.retrieve(queryEmbedding)
50-
val results = similarDocuments.map { it.embedded.text }
54+
val results = similarDocuments.map {
55+
SearchResult(
56+
source = it.embedded.metadata.toString(),
57+
content = it.embedded.text
58+
)
59+
}
5160
searchCache[query] = results
5261
return results
5362
}
54-
}
63+
}
64+
65+
data class SearchResult(
66+
val source: String,
67+
val content: String
68+
)

src/main/kotlin/cc/unitmesh/cf/domains/spec/SpecWorkflow.kt

+22-8
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,18 @@ class SpecWorkflow : Workflow() {
2626
)
2727

2828
override fun execute(prompt: StageContext, chatWebContext: ChatWebContext): Flowable<WorkflowResult> {
29-
// TODO clarify user question, 如系统包含了这些规范,你需要哪些规范?
30-
val specs = relevantSearch.search(chatWebContext.messages.last {
29+
val question = chatWebContext.messages.last {
3130
it.role == LlmMsg.ChatRole.User.value
32-
}.content)
31+
}.content
32+
33+
// TODO clarify user question, 如系统包含了这些规范,你需要哪些规范?
34+
val specs = relevantSearch.search(question)
3335

3436
val userMsg = EXECUTE.format()
35-
.replace("${'$'}{specs}", specs.joinToString("\n"))
36-
.replace("${'$'}{question}", chatWebContext.messages[0].content)
37+
.replace("${'$'}{specs}", specs.map {
38+
"source: ${it.source} content: ${it.content}"
39+
}.joinToString("\n"))
40+
.replace("${'$'}{question}", question)
3741

3842
val flowable = llmProvider.streamCompletion(listOf(
3943
LlmMsg.ChatMessage(LlmMsg.ChatRole.System, userMsg),
@@ -67,17 +71,27 @@ class SpecWorkflow : Workflow() {
6771
|
6872
|- 如果规范缺少对应的信息,你不要回答。
6973
|- 你必须回答用户的问题。
74+
|- 请根据客户的问题,返回对应的规范,并返回对应的 source 相关信息。
75+
|
7076
|
7177
|已有规范信息:
7278
|
7379
|```design
7480
|${'$'}{specs}
7581
|```
76-
|用户的问题:
77-
|${'$'}{question}
7882
|
79-
|现在请你根据规范信息,回答用户的问题。
83+
|示例:
84+
|用户的问题:哪些规范包含了架构设计?
85+
|回答:
86+
|###
87+
|
88+
|出处:后端代码规范的命名规范章节 // 这里根据规范的 source 项信息,写出对应的来源
89+
|// 这里,你需要返回是规范中的详细信息,而不是规范的标题。
90+
|###
8091
|
92+
|现在请你根据规范信息,回答用户的问题。
93+
|用户的问题:
94+
|${'$'}{question}
8195
|""".trimMargin()
8296
)
8397
}

src/main/kotlin/cc/unitmesh/cf/infrastructure/llms/embedding/SentenceTransformersEmbedding.kt

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import org.springframework.stereotype.Component
44
import cc.unitmesh.cf.STSemantic
55
import cc.unitmesh.nlp.embedding.Embedding
66
import cc.unitmesh.nlp.embedding.EmbeddingProvider
7-
import cc.unitmesh.rag.splitter.EncodingTokenizer
7+
import cc.unitmesh.nlp.embedding.EncodingTokenizer
88

99
@Component
1010
class SentenceTransformersEmbedding : EmbeddingProvider, EncodingTokenizer {
@@ -26,8 +26,7 @@ class SentenceTransformersEmbedding : EmbeddingProvider, EncodingTokenizer {
2626

2727
override fun decode(tokens: List<Int>): String {
2828
val map = tokens.map { it.toLong() }.toLongArray()
29-
val output = tokenizer.decode(map)
3029
// output will be "[CLS] blog [SEP]" for input "blog", so we need to remove the first and last token
31-
return output
30+
return tokenizer.decode(map)
3231
}
3332
}

0 commit comments

Comments
 (0)