Skip to content

Commit

Permalink
Optimized schemaless more to use cache further
Browse files Browse the repository at this point in the history
  • Loading branch information
frikky committed May 16, 2024
1 parent 131962f commit 72efa42
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 14 deletions.
29 changes: 22 additions & 7 deletions shuffle.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"net/url"
"net/http"
"io/ioutil"
"math/rand"
//"math/rand"
"crypto/tls"
"crypto/md5"
"encoding/hex"
Expand Down Expand Up @@ -154,6 +154,19 @@ func AddShuffleFile(name, namespace string, data []byte, shuffleConfig ShuffleCo
fileUrl += "&execution_id=" + shuffleConfig.ExecutionId
}

// Check if the file has already been uploaded, based on a hash of shuffleConfig.OrgId+name+namespace+data. No point in overwriting with the same data.
hasher := md5.New()
ctx := context.Background()
hasher.Write([]byte(fmt.Sprintf("%s%s%s%s", shuffleConfig.OrgId, name, namespace, string(data))))
cacheKey := hex.EncodeToString(hasher.Sum(nil))
cache, err := GetCache(ctx, cacheKey)
if err == nil {
cacheData := []byte(cache.([]uint8))
if len(cacheData) > 0 {
return nil
}
}

fileDataJson, err := json.Marshal(fileData)
if err != nil {
return err
Expand Down Expand Up @@ -275,6 +288,12 @@ func AddShuffleFile(name, namespace string, data []byte, shuffleConfig ShuffleCo
return err
}

// Store a placeholder value in the cache, as the point is to mark the upload as done, not to cache the file itself
err = SetCache(ctx, cacheKey, []byte("1"), 10)
if err != nil {
log.Printf("[ERROR] Schemaless (8): Error setting cache for file %#v from Shuffle backend: %s", name, err)
}

return nil
}

Expand All @@ -294,8 +313,6 @@ func GetShuffleFileById(id string, shuffleConfig ShuffleConfig) ([]byte, error)
cacheKey := hex.EncodeToString(hasher.Sum(nil))

// The file will be grabbed a ton, hence the cache actually speeding things up and reducing requests
sleepTime := time.Duration(25 + rand.Intn(1000-25)) * time.Millisecond
time.Sleep(sleepTime)

cache, err := GetCache(ctx, cacheKey)
if err == nil {
Expand Down Expand Up @@ -352,7 +369,6 @@ func FindShuffleFile(name, category string, shuffleConfig ShuffleConfig) ([]byte
}


log.Printf("[INFO] Schemaless: Finding file %#v in category %#v from Shuffle backend", name, category)

// 1. Get the category
// 2. Find the file in the category output
Expand All @@ -365,17 +381,16 @@ func FindShuffleFile(name, category string, shuffleConfig ShuffleConfig) ([]byte
hasher.Write([]byte(categoryUrl+shuffleConfig.Authorization+shuffleConfig.OrgId+shuffleConfig.ExecutionId))
cacheKey := hex.EncodeToString(hasher.Sum(nil))

sleepTime := time.Duration(25 + rand.Intn(1000-25)) * time.Millisecond
time.Sleep(sleepTime)

// Get the cache
ctx := context.Background()
var body []byte
cache, err := GetCache(ctx, cacheKey)
if err == nil {
//log.Printf("[INFO] Schemaless: FOUND file %#v in category %#v from cache", name, category)
body = []byte(cache.([]uint8))
//return cacheData, nil
} else {
log.Printf("[INFO] Schemaless: Finding file %#v in category %#v from Shuffle backend", name, category)
if len(shuffleConfig.ExecutionId) > 0 {
categoryUrl += "&execution_id=" + shuffleConfig.ExecutionId
}
Expand Down
57 changes: 50 additions & 7 deletions translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ var maxInputSize = 4000

func SaveQuery(inputStandard, gptTranslated string, shuffleConfig ShuffleConfig) error {
if len(shuffleConfig.URL) > 0 {
//return nil
return AddShuffleFile(inputStandard, "translation_ai_queries", []byte(gptTranslated), shuffleConfig)
}

Expand Down Expand Up @@ -58,7 +59,6 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC

systemMessage := fmt.Sprintf("Ensure the output is valid JSON, and does NOT add more keys to the standard. Make sure each key in the standard has a value from the user input. If values are nested, ALWAYS add the nested value in jq format such as 'secret.version.value'. Example: If the standard is ```{\"id\": \"The id of the ticket\", \"title\": \"The ticket title\"}```, and the user input is ```{\"key\": \"12345\", \"fields:\": {\"summary\": \"The title of the ticket\"}}```, the output should be ```{\"id\": \"key\", \"title\": \"fields.summary\"}```")

log.Printf("[DEBUG] Schemaless: Running GPT with system message: %s", systemMessage)

userQuery := fmt.Sprintf("Translate the given user input JSON structure to a standard format. Use the values from the standard to guide you what to look for. The standard format should follow the pattern:\n\n```json\n%s\n```\n\nUser Input:\n```json\n%s\n```\n\nGenerate the standard output structure without providing the expected output.", standardFormat, inputDataFormat)

Expand All @@ -70,6 +70,19 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC
return standardFormat, errors.New(fmt.Sprintf("Input data too long. Max is %d. Current is %d", maxInputSize, len(inputDataFormat)))
}

// Make an md5 of the query and use it as the cache key, to check for an already cached result
ctx := context.Background()
md5Query := fmt.Sprintf("%x", md5.Sum([]byte(shuffleConfig.OrgId+systemMessage+userQuery)))

cacheKey := fmt.Sprintf("translationquery-%s", md5Query)
cache, err := GetCache(ctx, cacheKey)
if err == nil {
contentOutput := string([]byte(cache.([]uint8)))
return contentOutput, nil
}

log.Printf("[DEBUG] Schemaless: Running GPT with system message: %s", systemMessage)

SaveQuery(keyTokenFile, userQuery, shuffleConfig)

openaiClient := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
Expand Down Expand Up @@ -111,6 +124,12 @@ func GptTranslate(keyTokenFile, standardFormat, inputDataFormat string, shuffleC
break
}

err = SetCache(ctx, cacheKey, []byte(contentOutput), 30)
if err != nil {
log.Printf("[ERROR] Schemaless: Error setting cache for key %s: %v", cacheKey, err)
return contentOutput, err
}

return contentOutput, nil
}

Expand Down Expand Up @@ -361,7 +380,8 @@ func SaveTranslation(inputStandard, gptTranslated string, shuffleConfig ShuffleC
gptTranslated = FixTranslationStructure(gptTranslated)

if len(shuffleConfig.URL) > 0 {
go AddShuffleFile(inputStandard, "translation_output", []byte(gptTranslated), shuffleConfig)
// Used to be a goroutine
return AddShuffleFile(inputStandard, "translation_output", []byte(gptTranslated), shuffleConfig)
return nil
}

Expand All @@ -387,6 +407,8 @@ func SaveTranslation(inputStandard, gptTranslated string, shuffleConfig ShuffleC

func SaveParsedInput(inputStandard string, gptTranslated []byte, shuffleConfig ShuffleConfig) error {
if len(shuffleConfig.URL) > 0 {
// FIXME: Should we upload everything? I think not
return nil
return AddShuffleFile(inputStandard, "translation_input", gptTranslated, shuffleConfig)
}

Expand Down Expand Up @@ -644,12 +666,33 @@ func handleSubStandard(ctx context.Context, subStandard string, returnJson strin

// For each item in the list, translate it to the substandard
// Maybe do this with recursive Translate() calls?

skipAfterCount := 50
var wg sync.WaitGroup
var mu sync.Mutex // Mutex to safely access parsedOutput slice

parsedOutput := [][]byte{}
for cnt, listItem := range listJson {
wg.Add(1) // Increment the wait group counter for each goroutine
// No goroutine for the first item, as we want to make sure caching is done properly
if cnt == 0 {
marshalledBody, err := json.Marshal(listItem)
if err != nil {
log.Printf("[ERROR] Schemaless: Error in marshalling of list item: %v", err)
continue
}

schemalessOutput, err := Translate(ctx, subStandard, marshalledBody, authConfig, "skip_substandard")
if err != nil {
log.Printf("[ERROR] Schemaless: Error in schemaless.Translate for sub list item: %v", err)
continue
}

parsedOutput = append(parsedOutput, schemalessOutput)
continue
}


wg.Add(1) // Increment the wait group counter for each goroutine
go func(cnt int, listItem interface{}) {
defer wg.Done() // Decrement the wait group counter when the goroutine completes

Expand All @@ -671,8 +714,8 @@ func handleSubStandard(ctx context.Context, subStandard string, returnJson strin
parsedOutput = append(parsedOutput, schemalessOutput)
}(cnt, listItem)

if cnt > 50 {
log.Printf("[WARNING] Schemaless: Breaking after 10 items in the list")
if cnt > skipAfterCount {
log.Printf("[WARNING] Schemaless: Breaking after %d items in the list", skipAfterCount)
break
}
}
Expand Down Expand Up @@ -830,9 +873,9 @@ func Translate(ctx context.Context, inputStandard string, inputValue []byte, inp
if err != nil {
log.Printf("[ERROR] Error in SaveTranslation (3): %v", err)
return []byte{}, err
}
}

log.Printf("[DEBUG] Saved translation to file. Should now run OpenAI and set cache!")
//log.Printf("[DEBUG] Saved GPT translation to file. Should now run OpenAI and set cache!")
inputStructure = []byte(gptTranslated)
}

Expand Down

0 comments on commit 72efa42

Please sign in to comment.