-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
feat(AI Transform Node): Reduce payload size #11965
Changes from all commits
260db64
f784d00
a745b80
d129a63
d685802
40ef2a8
f9e0c9e
f1950fc
d01af92
29f4878
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,134 @@ export function getSchemas() { | |
}; | ||
} | ||
|
||
//------ Reduce payload ------ | ||
|
||
const estimateNumberOfTokens = (item: unknown, averageTokenLength: number): number => { | ||
if (typeof item === 'object') { | ||
return Math.ceil(JSON.stringify(item).length / averageTokenLength); | ||
} | ||
|
||
return 0; | ||
}; | ||
|
||
const calculateRemainingTokens = (error: Error) => { | ||
// Expected message format: | ||
//'This model's maximum context length is 8192 tokens. However, your messages resulted in 10514 tokens.' | ||
const tokens = error.message.match(/\d+/g); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are assuming the error message contains two integer numbers. Could we have a comment telling what the expected error message would look like? Also, is there a chance that some other error would be thrown that could contain 2 numbers? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated as suggested, regarding other messages - we are checking message to contain specific subsisting 'maximum context length' |
||
|
||
if (!tokens || tokens.length < 2) throw error; | ||
|
||
const maxTokens = parseInt(tokens[0], 10); | ||
const currentTokens = parseInt(tokens[1], 10); | ||
|
||
return currentTokens - maxTokens; | ||
}; | ||
|
||
const trimParentNodesSchema = ( | ||
payload: AskAiRequest.RequestPayload, | ||
remainingTokensToReduce: number, | ||
averageTokenLength: number, | ||
) => { | ||
//check if parent nodes schema takes more tokens than available | ||
let parentNodesTokenCount = estimateNumberOfTokens(payload.context.schema, averageTokenLength); | ||
|
||
if (remainingTokensToReduce > parentNodesTokenCount) { | ||
remainingTokensToReduce -= parentNodesTokenCount; | ||
payload.context.schema = []; | ||
} | ||
|
||
//remove parent nodes not referenced in the prompt | ||
if (payload.context.schema.length) { | ||
const nodes = [...payload.context.schema]; | ||
|
||
for (let nodeIndex = 0; nodeIndex < nodes.length; nodeIndex++) { | ||
if (payload.question.includes(nodes[nodeIndex].nodeName)) continue; | ||
|
||
const nodeTokens = estimateNumberOfTokens(nodes[nodeIndex], averageTokenLength); | ||
remainingTokensToReduce -= nodeTokens; | ||
parentNodesTokenCount -= nodeTokens; | ||
payload.context.schema.splice(nodeIndex, 1); | ||
|
||
if (remainingTokensToReduce <= 0) break; | ||
} | ||
} | ||
|
||
return [remainingTokensToReduce, parentNodesTokenCount]; | ||
}; | ||
|
||
const trimInputSchemaProperties = ( | ||
payload: AskAiRequest.RequestPayload, | ||
remainingTokensToReduce: number, | ||
averageTokenLength: number, | ||
parentNodesTokenCount: number, | ||
) => { | ||
if (remainingTokensToReduce <= 0) return remainingTokensToReduce; | ||
|
||
//remove properties not referenced in the prompt from the input schema | ||
if (Array.isArray(payload.context.inputSchema.schema.value)) { | ||
const props = [...payload.context.inputSchema.schema.value]; | ||
|
||
for (let index = 0; index < props.length; index++) { | ||
const key = props[index].key; | ||
|
||
if (key && payload.question.includes(key)) continue; | ||
|
||
const propTokens = estimateNumberOfTokens(props[index], averageTokenLength); | ||
remainingTokensToReduce -= propTokens; | ||
payload.context.inputSchema.schema.value.splice(index, 1); | ||
|
||
if (remainingTokensToReduce <= 0) break; | ||
} | ||
} | ||
|
||
//if tokensToReduce is still remaining, remove all parent nodes | ||
if (remainingTokensToReduce > 0) { | ||
payload.context.schema = []; | ||
remainingTokensToReduce -= parentNodesTokenCount; | ||
} | ||
|
||
return remainingTokensToReduce; | ||
}; | ||
|
||
/** | ||
* Attempts to reduce the size of the payload to fit within token limits or throws an error if unsuccessful, | ||
* payload would be modified in place | ||
* | ||
* @param {AskAiRequest.RequestPayload} payload - The request payload to be trimmed, | ||
* 'schema' and 'inputSchema.schema' will be modified. | ||
* @param {Error} error - The error to throw if the token reduction fails. | ||
* @param {number} [averageTokenLength=4] - The average token length used for estimation. | ||
* @throws {Error} - Throws the provided error if the payload cannot be reduced sufficiently. | ||
*/ | ||
export function reducePayloadSizeOrThrow( | ||
payload: AskAiRequest.RequestPayload, | ||
error: Error, | ||
averageTokenLength = 4, | ||
) { | ||
try { | ||
let remainingTokensToReduce = calculateRemainingTokens(error); | ||
|
||
const [remaining, parentNodesTokenCount] = trimParentNodesSchema( | ||
payload, | ||
remainingTokensToReduce, | ||
averageTokenLength, | ||
); | ||
|
||
remainingTokensToReduce = remaining; | ||
|
||
remainingTokensToReduce = trimInputSchemaProperties( | ||
payload, | ||
remainingTokensToReduce, | ||
averageTokenLength, | ||
parentNodesTokenCount, | ||
); | ||
|
||
if (remainingTokensToReduce > 0) throw error; | ||
} catch (e) { | ||
throw e; | ||
} | ||
} | ||
|
||
export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) { | ||
const schemas = getSchemas(); | ||
|
||
|
@@ -83,6 +211,11 @@ export async function generateCodeForAiTransform(prompt: string, path: string, r | |
code = generatedCode; | ||
break; | ||
} catch (e) { | ||
if (e.message.includes('maximum context length')) { | ||
reducePayloadSizeOrThrow(payload, e); | ||
continue; | ||
} | ||
|
||
retries--; | ||
if (!retries) throw e; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice tests 👏