Skip to content

Commit

Permalink
Parse step response when JSON schema (#10)
Browse files Browse the repository at this point in the history
* Parse step response when JSON schema

* Add tests
  • Loading branch information
neoxelox authored Dec 30, 2024
1 parent 4923c78 commit d242b71
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 12 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "promptl-ai",
"version": "0.3.3",
"version": "0.3.5",
"author": "Latitude Data",
"license": "MIT",
"description": "Compiler for PromptL, the prompt language",
Expand Down
133 changes: 133 additions & 0 deletions src/compiler/base/nodes/tags/step.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import CompileError from '$promptl/error/error'
import { complete, getExpectedError } from "$promptl/compiler/test/helpers";
import { removeCommonIndent } from "$promptl/compiler/utils";
import { Chain } from "$promptl/index";
import { describe, expect, it, vi } from "vitest";

describe("step tags", async () => {
it("does not create a variable from response if not specified", async () => {
const mock = vi.fn();
const prompt = removeCommonIndent(`
<step>
Ensure truthfulness of the following statement, give a reason and a confidence score.
Statement: fake statement
</step>
<step>
Now correct the statement if it is not true.
</step>
`);

const chain = new Chain({ prompt, parameters: { mock }});
await complete({ chain, callback: async () => `
The statement is not true because it is fake. My confidence score is 100.
`.trim()});

expect(mock).not.toHaveBeenCalled();
});

it("creates a text variable from response if specified", async () => {
const mock = vi.fn();
const prompt = removeCommonIndent(`
<step as="analysis">
Ensure truthfulness of the following statement, give a reason and a confidence score.
Statement: fake statement
</step>
<step>
{{ mock(analysis) }}
Now correct the statement if it is not true.
</step>
`);

const chain = new Chain({ prompt, parameters: { mock }});
await complete({ chain, callback: async () => `
The statement is not true because it is fake. My confidence score is 100.
`.trim()});

expect(mock).toHaveBeenCalledWith("The statement is not true because it is fake. My confidence score is 100.");
});

it("creates an object variable from response if specified and schema is provided", async () => {
const mock = vi.fn();
const prompt = removeCommonIndent(`
<step as="analysis" schema={{{type: "object", properties: {truthful: {type: "boolean"}, reason: {type: "string"}, confidence: {type: "integer"}}, required: ["truthful", "reason", "confidence"]}}}>
Ensure truthfulness of the following statement, give a reason and a confidence score.
Statement: fake statement
</step>
<step>
{{ mock(analysis) }}
{{ if !analysis.truthful && analysis.confidence > 50 }}
Correct the statement taking into account the reason: '{{ analysis.reason }}'.
{{ endif }}
</step>
`);

const chain = new Chain({ prompt, parameters: { mock }});
const { messages } = await complete({ chain, callback: async () => `
{
"truthful": false,
"reason": "It is fake",
"confidence": 100
}
`.trim()});

expect(mock).toHaveBeenCalledWith({
truthful: false,
reason: "It is fake",
confidence: 100
});
expect(messages[2]!.content).toEqual("Correct the statement taking into account the reason: 'It is fake'.");
});

it("fails creating an object variable from response if specified and schema is provided but response is invalid", async () => {
const mock = vi.fn();
const prompt = removeCommonIndent(`
<step as="analysis" schema={{{type: "object", properties: {truthful: {type: "boolean"}, reason: {type: "string"}, confidence: {type: "integer"}}, required: ["truthful", "reason", "confidence"]}}}>
Ensure truthfulness of the following statement, give a reason and a confidence score.
Statement: fake statement
</step>
<step>
{{ mock(analysis) }}
{{ if !analysis.truthful && analysis.confidence > 50 }}
Correct the statement taking into account the reason: '{{ analysis.reason }}'.
{{ endif }}
</step>
`);

const chain = new Chain({ prompt, parameters: { mock }});
const error = await getExpectedError(() => complete({ chain, callback: async () => `
Bad JSON.
`.trim()}), CompileError)
expect(error.code).toBe('function-call-error')

expect(mock).not.toHaveBeenCalled();
});

it("creates a raw variable from response if specified", async () => {
const mock = vi.fn();
const prompt = removeCommonIndent(`
<step raw="analysis">
Ensure truthfulness of the following statement, give a reason and a confidence score.
Statement: fake statement
</step>
<step>
{{ mock(analysis) }}
Now correct the statement if it is not true.
</step>
`);

const chain = new Chain({ prompt, parameters: { mock }});
await complete({ chain, callback: async () => `
The statement is not true because it is fake. My confidence score is 100.
`.trim()});

expect(mock).toHaveBeenCalledWith({
role: "assistant",
content: [
{
type: "text",
text: "The statement is not true because it is fake. My confidence score is 100.",
},
],
});
});
});
24 changes: 13 additions & 11 deletions src/compiler/base/nodes/tags/step.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export async function compile(

const stepResponse = popStepResponse()

const { as: textVarName, raw: messageVarName, ...config } = attributes
const { as: responseVarName, raw: messageVarName, ...config } = attributes

// The step must be processed.
if (stepResponse === undefined) {
Expand All @@ -57,13 +57,6 @@ export async function compile(
}

// The step has already been process, this is the continuation of the chain.
if ('as' in attributes) {
if (!tagAttributeIsLiteral(node, 'as')) {
baseNodeError(errors.invalidStaticAttribute('as'), node)
}

scope.set(String(textVarName), stepResponse)
}

if ('raw' in attributes) {
if (!tagAttributeIsLiteral(node, 'raw')) {
Expand All @@ -73,14 +66,23 @@ export async function compile(
scope.set(String(messageVarName), stepResponse)
}

// The step has already been process, this is the continuation of the chain.
if ('as' in attributes) {
if (!tagAttributeIsLiteral(node, 'as')) {
baseNodeError(errors.invalidStaticAttribute('as'), node)
}

const textVarValue = (stepResponse?.content ?? []).filter(c => c.type === ContentType.text).map(c => c.text).join('')
scope.set(String(textVarName), textVarValue)
const textResponse = (stepResponse?.content ?? []).filter(c => c.type === ContentType.text).map(c => c.text).join('')
let responseVarValue = textResponse

if ("schema" in config) {
try {
responseVarValue = JSON.parse(responseVarValue.trim())
} catch (error) {
baseNodeError(errors.functionCallError(error), node)
}
}

scope.set(String(responseVarName), responseVarValue)
}

groupContent()
Expand Down

0 comments on commit d242b71

Please # to comment.