From 8e7ee9643269c78427ac0ea9a9e29c5f3e7c7077 Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Wed, 16 Oct 2024 10:56:59 +0800 Subject: [PATCH 1/2] special charactor handle bug fix --- .../extensions/ai-security-guard/main.go | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 3051fe22ea..26582ecb3b 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -81,7 +81,7 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) func urlEncoding(rawStr string) string { encodedStr := url.PathEscape(rawStr) - encodedStr = strings.ReplaceAll(encodedStr, "+", "%20") + encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B") encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A") encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D") encodedStr = strings.ReplaceAll(encodedStr, "&", "%26") @@ -106,7 +106,7 @@ func getSign(params map[string]string, secret string) string { }) canonicalStr := strings.Join(paramArray, "&") signStr := "POST&%2F&" + urlEncoding(canonicalStr) - // proxywasm.LogInfo(signStr) + proxywasm.LogDebugf("String to sign is:" + signStr) return hmacSha1(signStr, secret) } @@ -196,10 +196,11 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { log.Debugf("checking request body...") - content := gjson.GetBytes(body, config.requestContentJsonPath).String() + content := gjson.GetBytes(body, config.requestContentJsonPath).Raw model := gjson.GetBytes(body, "model").Raw ctx.SetContext("requestModel", model) - if content != "" { + log.Debugf("Raw response content is: %s", content) + if len(content) > 0 { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) params := map[string]string{ @@ -212,7 +213,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": config.requestCheckService, - "ServiceParameters": `{"content": "` + content + `"}`, + "ServiceParameters": fmt.Sprintf(`{"content": %s}`, content), } signature := getSign(params, config.sk+"&") reqParams := url.Values{} @@ -339,7 +340,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ if isStreamingResponse { content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) } else { - content = gjson.GetBytes(body, config.responseContentJsonPath).String() + content = gjson.GetBytes(body, config.responseContentJsonPath).Raw } log.Debugf("Raw response content is: %s", content) if len(content) > 0 { @@ -355,7 +356,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ "AccessKeyId": config.ak, "Timestamp": timestamp, "Service": config.responseCheckService, - "ServiceParameters": `{"content": "` + content + `"}`, + "ServiceParameters": fmt.Sprintf(`{"content": %s}`, content), } signature := getSign(params, config.sk+"&") reqParams := url.Values{} @@ -432,10 +433,10 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string { strChunks := []string{} for _, chunk := range chunks { // Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}] - jsonObj := gjson.GetBytes(chunk, jsonPath) - if jsonObj.Exists() { - strChunks = append(strChunks, jsonObj.String()) + jsonRaw := gjson.GetBytes(chunk, jsonPath).Raw + if len(jsonRaw) > 2 { + strChunks = append(strChunks, jsonRaw[1:len(jsonRaw)-1]) } } - return strings.Join(strChunks, "") + return fmt.Sprintf(`"%s"`, strings.Join(strChunks, "")) } From b9279e17dce6969e4afd516821689c12fd1a74ce Mon Sep 17 00:00:00 2001 From: rinfx <893383980@qq.com> Date: Thu, 17 Oct 2024 14:07:10 +0800 Subject: [PATCH 2/2] log optimized & crash bugfix --- plugins/wasm-go/extensions/ai-security-guard/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 26582ecb3b..ca59f7f6a1 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -106,7 +106,7 @@ func getSign(params map[string]string, secret string) string { }) canonicalStr := strings.Join(paramArray, "&") signStr := "POST&%2F&" + urlEncoding(canonicalStr) - proxywasm.LogDebugf("String to sign is:" + signStr) + proxywasm.LogDebugf("String to sign is: %s", signStr) return hmacSha1(signStr, secret) } @@ -401,10 +401,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ jsonData = []byte(denyMessage) } else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") { randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String(), randomID, model)) + jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model)) } else { randomID := generateRandomID() - jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String())) + jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage)) } delete(hdsMap, "content-length") hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}