Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(AI Transform Node): Reduce payload size #11965

70 changes: 69 additions & 1 deletion packages/editor-ui/src/components/ButtonParameter/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { generateCodeForAiTransform } from './utils';
import { generateCodeForAiTransform, reducePayloadSizeOrThrow } from './utils';
import { createPinia, setActivePinia } from 'pinia';
import { generateCodeForPrompt } from '@/api/ai';
import type { AskAiRequest } from '@/types/assistant.types';
import type { Schema } from '@/Interface';

vi.mock('./utils', async () => {
const actual = await vi.importActual('./utils');
Expand Down Expand Up @@ -86,3 +88,69 @@ describe('generateCodeForAiTransform - Retry Tests', () => {
expect(generateCodeForPrompt).toHaveBeenCalledTimes(1);
});
});

// Builds a brand-new request payload fixture on every call so that tests can
// mutate it freely: two parent nodes, two input-schema properties and a
// question that mentions only 'node1' and 'prop1'.
const mockPayload = () => {
	const parentNodes = [
		{ nodeName: 'node1', data: 'some data' },
		{ nodeName: 'node2', data: 'other data' },
	];

	const inputProperties = [
		{ key: 'prop1', value: 'value1' },
		{ key: 'prop2', value: 'value2' },
	];

	const fixture = {
		context: {
			schema: parentNodes,
			inputSchema: { schema: { value: inputProperties } },
		},
		question: 'What is node1 and prop1?',
	};

	return fixture as unknown as AskAiRequest.RequestPayload;
};

// Test suite for reducePayloadSizeOrThrow. Every case starts from a fresh
// mockPayload() and passes an Error whose message carries two numbers
// (the limit first, the provided token count second).
describe('reducePayloadSizeOrThrow', () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice tests 👏

it('reduces schema size when tokens exceed the limit', () => {
const payload = mockPayload();
// 4 tokens over the limit; the question mentions node1 but not node2.
const error = new Error('Limit is 100 tokens, but 104 were provided');

reducePayloadSizeOrThrow(payload, error);

expect(payload.context.schema.length).toBe(1);
expect(payload.context.schema[0]).toEqual({ nodeName: 'node1', data: 'some data' });
});

it('removes unreferenced properties in input schema', () => {
const payload = mockPayload();
// 50 tokens over the limit; the question mentions prop1 but not prop2.
const error = new Error('Limit is 100 tokens, but 150 were provided');

reducePayloadSizeOrThrow(payload, error);

expect(payload.context.inputSchema.schema.value.length).toBe(1);
expect((payload.context.inputSchema.schema.value as Schema[])[0].key).toBe('prop1');
});

it('removes all parent nodes if needed', () => {
const payload = mockPayload();
const error = new Error('Limit is 100 tokens, but 150 were provided');

// With an empty question, no node name or property key is referenced.
payload.question = '';

reducePayloadSizeOrThrow(payload, error);

expect(payload.context.schema.length).toBe(0);
});

it('throws error if tokens still exceed after reductions', () => {
const payload = mockPayload();
// 100 tokens over the limit; the original error object is expected to be rethrown.
const error = new Error('Limit is 100 tokens, but 200 were provided');

expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
});

it('throws error if message format is invalid', () => {
const payload = mockPayload();
// The message contains no numbers, so no token counts can be extracted from it.
const error = new Error('Invalid token message format');

expect(() => reducePayloadSizeOrThrow(payload, error)).toThrowError(error);
});
});
133 changes: 133 additions & 0 deletions packages/editor-ui/src/components/ButtonParameter/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,134 @@ export function getSchemas() {
};
}

//------ Reduce payload ------

/**
 * Roughly estimates how many tokens a value occupies once serialized.
 *
 * Only values whose `typeof` is `'object'` are measured: the length of their
 * JSON representation is divided by `averageTokenLength` and rounded up.
 * Every other kind of value yields 0.
 *
 * @param item - Value to measure.
 * @param averageTokenLength - Assumed average number of characters per token.
 * @returns Estimated token count.
 */
const estimateNumberOfTokens = (item: unknown, averageTokenLength: number): number => {
	if (typeof item !== 'object') return 0;

	const serializedLength = JSON.stringify(item).length;

	return Math.ceil(serializedLength / averageTokenLength);
};

/**
 * Derives how many tokens the request exceeds the limit by, based on the
 * numbers found in the error message.
 *
 * The first number in the message is read as the maximum and the second as the
 * current token count; the result is current minus maximum.
 * Rethrows the given error when the message holds fewer than two numbers.
 */
const calculateRemainingTokens = (error: Error) => {
// Expected message format:
//'This model's maximum context length is 8192 tokens. However, your messages resulted in 10514 tokens.'
const tokens = error.message.match(/\d+/g);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are assuming the error message contains two integer numbers. Could we have a comment telling what the expected error message would look like?

Also, is there a chance that some other error would be thrown that could contain 2 numbers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated as suggested, regarding other messages - we are checking message to contain specific subsisting 'maximum context length'


// Without at least two numbers the excess cannot be computed, so give up with the original error.
if (!tokens || tokens.length < 2) throw error;

// Position in the message decides the meaning: limit first, actual usage second.
const maxTokens = parseInt(tokens[0], 10);
const currentTokens = parseInt(tokens[1], 10);

return currentTokens - maxTokens;
};

/**
 * Drops parent-node schemas from `payload.context.schema` to save tokens.
 * The payload is modified in place.
 *
 * If the number of tokens still to be saved exceeds the estimated size of the
 * whole parent schema, every parent node is dropped at once. Otherwise, nodes
 * whose name does not occur in `payload.question` are removed front to back
 * until the remaining amount is no longer positive.
 *
 * @param payload - Request payload; `context.schema` is mutated or replaced.
 * @param remainingTokensToReduce - Estimated tokens that still have to be removed.
 * @param averageTokenLength - Assumed average number of characters per token.
 * @returns `[remainingTokensToReduce, parentNodesTokenCount]`: the tokens still
 * to be removed, and the token estimate tracked for the parent-node schema.
 */
const trimParentNodesSchema = (
	payload: AskAiRequest.RequestPayload,
	remainingTokensToReduce: number,
	averageTokenLength: number,
): [number, number] => {
	// Check whether the parent nodes schema takes more tokens than are available.
	let parentNodesTokenCount = estimateNumberOfTokens(payload.context.schema, averageTokenLength);

	if (remainingTokensToReduce > parentNodesTokenCount) {
		// NOTE(review): parentNodesTokenCount is returned unchanged on this path and
		// trimInputSchemaProperties subtracts it again, so these tokens look double
		// counted. Left as is because callers may rely on the current accounting.
		remainingTokensToReduce -= parentNodesTokenCount;
		payload.context.schema = [];
	}

	// Remove parent nodes that are not referenced in the prompt.
	// The live array is walked directly and the index only advances when a node
	// is kept; indexing a snapshot while splicing the live array would remove the
	// wrong element after the first deletion.
	const nodes = payload.context.schema;
	let nodeIndex = 0;

	while (nodeIndex < nodes.length) {
		const node = nodes[nodeIndex];

		if (payload.question.includes(node.nodeName)) {
			nodeIndex++;
			continue;
		}

		const nodeTokens = estimateNumberOfTokens(node, averageTokenLength);
		remainingTokensToReduce -= nodeTokens;
		parentNodesTokenCount -= nodeTokens;
		nodes.splice(nodeIndex, 1);

		if (remainingTokensToReduce <= 0) break;
	}

	return [remainingTokensToReduce, parentNodesTokenCount];
};

/**
 * Removes input-schema properties that the prompt does not mention and, if that
 * is still not enough, drops every parent-node schema. The payload is modified
 * in place.
 *
 * @param payload - Request payload; `context.inputSchema.schema.value` and
 * `context.schema` may be modified.
 * @param remainingTokensToReduce - Estimated tokens that still have to be removed;
 * when it is not positive the function returns it immediately.
 * @param averageTokenLength - Assumed average number of characters per token.
 * @param parentNodesTokenCount - Token estimate credited when all parent nodes
 * are dropped as a last resort.
 * @returns The number of tokens that still have to be removed (not positive on success).
 */
const trimInputSchemaProperties = (
	payload: AskAiRequest.RequestPayload,
	remainingTokensToReduce: number,
	averageTokenLength: number,
	parentNodesTokenCount: number,
): number => {
	if (remainingTokensToReduce <= 0) return remainingTokensToReduce;

	// Remove properties not referenced in the prompt from the input schema.
	const props = payload.context.inputSchema.schema.value;

	if (Array.isArray(props)) {
		// Walk the live array and advance only past kept properties, so every
		// splice removes exactly the property whose tokens were just counted.
		let index = 0;

		while (index < props.length) {
			const prop = props[index];
			const key = prop.key;

			if (key && payload.question.includes(key)) {
				index++;
				continue;
			}

			remainingTokensToReduce -= estimateNumberOfTokens(prop, averageTokenLength);
			props.splice(index, 1);

			if (remainingTokensToReduce <= 0) break;
		}
	}

	// If tokens still remain to be reduced, remove all parent nodes.
	if (remainingTokensToReduce > 0) {
		payload.context.schema = [];
		remainingTokensToReduce -= parentNodesTokenCount;
	}

	return remainingTokensToReduce;
};

/**
 * Attempts to reduce the size of the payload so that it fits within the token
 * limit reported by `error`. The payload is modified in place.
 *
 * Parent-node schemas are trimmed first, then input-schema properties. If the
 * estimated excess is still positive afterwards, the provided error is thrown.
 *
 * @param payload - The request payload to trim; `context.schema` and
 * `context.inputSchema.schema` may be modified.
 * @param error - The token-limit error. Its message supplies the token counts,
 * and it is the error thrown when the reduction fails.
 * @param averageTokenLength - Average token length (in characters) used for
 * estimation; defaults to 4.
 * @throws The provided `error` if its message cannot be parsed or the payload
 * cannot be reduced sufficiently.
 */
export function reducePayloadSizeOrThrow(
	payload: AskAiRequest.RequestPayload,
	error: Error,
	averageTokenLength = 4,
): void {
	// No try/catch here: a catch that only rethrows adds nothing, errors
	// propagate to the caller either way.
	const tokensOverLimit = calculateRemainingTokens(error);

	const [afterParentNodes, parentNodesTokenCount] = trimParentNodesSchema(
		payload,
		tokensOverLimit,
		averageTokenLength,
	);

	const remainingTokensToReduce = trimInputSchemaProperties(
		payload,
		afterParentNodes,
		averageTokenLength,
		parentNodesTokenCount,
	);

	if (remainingTokensToReduce > 0) throw error;
}

export async function generateCodeForAiTransform(prompt: string, path: string, retries = 1) {
const schemas = getSchemas();

Expand All @@ -83,6 +211,11 @@ export async function generateCodeForAiTransform(prompt: string, path: string, r
code = generatedCode;
break;
} catch (e) {
if (e.message.includes('maximum context length')) {
reducePayloadSizeOrThrow(payload, e);
continue;
}

retries--;
if (!retries) throw e;
}
Expand Down
Loading