From 72efa429d6726c945760de79885e8ed9f1ef5185 Mon Sep 17 00:00:00 2001 From: Frikky Date: Fri, 17 May 2024 01:25:21 +0200 Subject: [PATCH] Optimized schemaless more to use cache further --- shuffle.go | 29 +++++++++++++++++++------- translate.go | 57 +++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/shuffle.go b/shuffle.go index 960e23f..164f83b 100644 --- a/shuffle.go +++ b/shuffle.go @@ -13,7 +13,7 @@ import ( "net/url" "net/http" "io/ioutil" - "math/rand" + //"math/rand" "crypto/tls" "crypto/md5" "encoding/hex" @@ -154,6 +154,19 @@ func AddShuffleFile(name, namespace string, data []byte, shuffleConfig ShuffleCo fileUrl += "&execution_id=" + shuffleConfig.ExecutionId } + // Check if the file has already been uploaded based on shuffleConfig.OrgId+namespace+data. No point in overwriting with the same data. + hasher := md5.New() + ctx := context.Background() + hasher.Write([]byte(fmt.Sprintf("%s%s%s%s", shuffleConfig.OrgId, name, namespace, string(data)))) + cacheKey := hex.EncodeToString(hasher.Sum(nil)) + cache, err := GetCache(ctx, cacheKey) + if err == nil { + cacheData := []byte(cache.([]uint8)) + if len(cacheData) > 0 { + return nil + } + } + fileDataJson, err := json.Marshal(fileData) if err != nil { return err @@ -275,6 +288,12 @@ func AddShuffleFile(name, namespace string, data []byte, shuffleConfig ShuffleCo return err } + // Update with basically nothing, as the point isn't to get the file itself + err = SetCache(ctx, cacheKey, []byte("1"), 10) + if err != nil { + log.Printf("[ERROR] Schemaless (8): Error setting cache for file %#v from Shuffle backend: %s", name, err) + } + return nil } @@ -294,8 +313,6 @@ func GetShuffleFileById(id string, shuffleConfig ShuffleConfig) ([]byte, error) cacheKey := hex.EncodeToString(hasher.Sum(nil)) // The file will be grabbed a ton, hence the cache actually speeding things up and reducing requests - sleepTime := time.Duration(25 + rand.Intn(1000-25)) * time.Millisecond - time.Sleep(sleepTime) cache, err := GetCache(ctx, cacheKey) if err == nil { @@ -352,7 +369,6 @@ func FindShuffleFile(name, category string, shuffleConfig ShuffleConfig) ([]byte } - log.Printf("[INFO] Schemaless: Finding file %#v in category %#v from Shuffle backend", name, category) // 1. Get the category // 2. Find the file in the category output @@ -365,17 +381,16 @@ func FindShuffleFile(name, category string, shuffleConfig ShuffleConfig) ([]byte hasher.Write([]byte(categoryUrl+shuffleConfig.Authorization+shuffleConfig.OrgId+shuffleConfig.ExecutionId)) cacheKey := hex.EncodeToString(hasher.Sum(nil)) - sleepTime := time.Duration(25 + rand.Intn(1000-25)) * time.Millisecond - time.Sleep(sleepTime) - // Get the cache ctx := context.Background() var body []byte cache, err := GetCache(ctx, cacheKey) if err == nil { + //log.Printf("[INFO] Schemaless: FOUND file %#v in category %#v from cache", name, category) body = []byte(cache.([]uint8)) //return cacheData, nil } else { + log.Printf("[INFO] Schemaless: Finding file %#v in category %#v from Shuffle backend", name, category) if len(shuffleConfig.ExecutionId) > 0 { categoryUrl += "&execution_id=" + shuffleConfig.ExecutionId } diff --git a/translate.go b/translate.go index 5cbf2d4..cac5b75 100644 --- a/translate.go +++ b/translate.go @@ -28,6 +28,7 @@ var maxInputSize = 4000 func SaveQuery(inputStandard, gptTranslated string, shuffleConfig ShuffleConfig) error { if len(shuffleConfig.URL) > 0 { + //return nil return AddShuffleFile(inputStandard, "translation_ai_queries", []byte(gptTranslated), shuffleConfig) } @@ -58,7 +59,6 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC systemMessage := fmt.Sprintf("Ensure the output is valid JSON, and does NOT add more keys to the standard. Make sure each key in the standard has a value from the user input. If values are nested, ALWAYS add the nested value in jq format such as 'secret.version.value'. Example: If the standard is ```{\"id\": \"The id of the ticket\", \"title\": \"The ticket title\"}```, and the user input is ```{\"key\": \"12345\", \"fields:\": {\"summary\": \"The title of the ticket\"}}```, the output should be ```{\"id\": \"key\", \"title\": \"fields.summary\"}```") - log.Printf("[DEBUG] Schemaless: Running GPT with system message: %s", systemMessage) userQuery := fmt.Sprintf("Translate the given user input JSON structure to a standard format. Use the values from the standard to guide you what to look for. The standard format should follow the pattern:\n\n```json\n%s\n```\n\nUser Input:\n```json\n%s\n```\n\nGenerate the standard output structure without providing the expected output.", standardFormat, inputDataFormat) @@ -70,6 +70,19 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC return standardFormat, errors.New(fmt.Sprintf("Input data too long. Max is %d. Current is %d", maxInputSize, len(inputDataFormat))) } + // Make md5 of the query, and put it in cache to check + ctx := context.Background() + md5Query := fmt.Sprintf("%x", md5.Sum([]byte(shuffleConfig.OrgId+systemMessage+userQuery))) + + cacheKey := fmt.Sprintf("translationquery-%s", md5Query) + cache, err := GetCache(ctx, cacheKey) + if err == nil { + contentOutput := string([]byte(cache.([]uint8))) + return contentOutput, nil + } + + log.Printf("[DEBUG] Schemaless: Running GPT with system message: %s", systemMessage) + SaveQuery(keyTokenFile, userQuery, shuffleConfig) openaiClient := openai.NewClient(os.Getenv("OPENAI_API_KEY")) @@ -111,6 +124,12 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC break } + err = SetCache(ctx, cacheKey, []byte(contentOutput), 30) + if err != nil { + log.Printf("[ERROR] Schemaless: Error setting cache for key %s: %v", cacheKey, err) + return contentOutput, err + } + return contentOutput, nil } @@ -361,7 +380,8 @@ func SaveTranslation(inputStandard, gptTranslated string, shuffleConfig ShuffleC gptTranslated = FixTranslationStructure(gptTranslated) if len(shuffleConfig.URL) > 0 { - go AddShuffleFile(inputStandard, "translation_output", []byte(gptTranslated), shuffleConfig) + // Used to be a goroutine + return AddShuffleFile(inputStandard, "translation_output", []byte(gptTranslated), shuffleConfig) return nil } @@ -387,6 +407,8 @@ func SaveTranslation(inputStandard, gptTranslated string, shuffleConfig ShuffleC func SaveParsedInput(inputStandard string, gptTranslated []byte, shuffleConfig ShuffleConfig) error { if len(shuffleConfig.URL) > 0 { + // FIXME: Should we upload everything? I think not + return nil return AddShuffleFile(inputStandard, "translation_input", gptTranslated, shuffleConfig) } @@ -644,12 +666,33 @@ func handleSubStandard(ctx context.Context, subStandard string, returnJson strin // For each item in the list, translate it to the substandard // Maybe do this with recursive Translate() calls? + + skipAfterCount := 50 var wg sync.WaitGroup var mu sync.Mutex // Mutex to safely access parsedOutput slice + parsedOutput := [][]byte{} for cnt, listItem := range listJson { - wg.Add(1) // Increment the wait group counter for each goroutine + // No goroutine on the first ones as we want to make sure caching is done properly + if cnt == 0 { + marshalledBody, err := json.Marshal(listItem) + if err != nil { + log.Printf("[ERROR] Schemaless: Error in marshalling of list item: %v", err) + continue + } + + schemalessOutput, err := Translate(ctx, subStandard, marshalledBody, authConfig, "skip_substandard") + if err != nil { + log.Printf("[ERROR] Schemaless: Error in schemaless.Translate for sub list item: %v", err) + continue + } + + parsedOutput = append(parsedOutput, schemalessOutput) + continue + } + + wg.Add(1) // Increment the wait group counter for each goroutine go func(cnt int, listItem interface{}) { defer wg.Done() // Decrement the wait group counter when the goroutine completes @@ -671,8 +714,8 @@ func handleSubStandard(ctx context.Context, subStandard string, returnJson strin parsedOutput = append(parsedOutput, schemalessOutput) }(cnt, listItem) - if cnt > 50 { - log.Printf("[WARNING] Schemaless: Breaking after 10 items in the list") + if cnt > skipAfterCount { + log.Printf("[WARNING] Schemaless: Breaking after %d items in the list", skipAfterCount) break } } @@ -830,9 +873,9 @@ func Translate(ctx context.Context, inputStandard string, inputValue []byte, inp if err != nil { log.Printf("[ERROR] Error in SaveTranslation (3): %v", err) return []byte{}, err - } + } - log.Printf("[DEBUG] Saved translation to file. Should now run OpenAI and set cache!") + //log.Printf("[DEBUG] Saved GPT translation to file. Should now run OpenAI and set cache!") inputStructure = []byte(gptTranslated) }