Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Feature/Add custom prompt to SQLDbChain #847

Merged
merged 1 commit into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,34 @@
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains/sql_db'
import { getBaseClasses } from '../../../src/utils'
import { getBaseClasses, getInputVariables } from '../../../src/utils'
import { DataSource } from 'typeorm'
import { SqlDatabase } from 'langchain/sql_db'
import { BaseLanguageModel } from 'langchain/base_language'
import { PromptTemplate, PromptTemplateInput } from 'langchain/prompts'
import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler'
import { DataSourceOptions } from 'typeorm/data-source'

type DatabaseType = 'sqlite' | 'postgres' | 'mssql' | 'mysql'

const defaultPrompt = `Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the tables listed below.

{table_info}

Question: {input}`

class SqlDatabaseChain_Chains implements INode {
label: string
name: string
Expand All @@ -23,7 +43,7 @@ class SqlDatabaseChain_Chains implements INode {
constructor() {
this.label = 'Sql Database Chain'
this.name = 'sqlDatabaseChain'
this.version = 1.0
this.version = 2.0
this.type = 'SqlDatabaseChain'
this.icon = 'sqlchain.svg'
this.category = 'Chains'
Expand Down Expand Up @@ -64,6 +84,19 @@ class SqlDatabaseChain_Chains implements INode {
name: 'url',
type: 'string',
placeholder: '1270.0.0.1:5432/chinook'
},
{
label: 'Custom Prompt',
name: 'customPrompt',
type: 'string',
description:
'You can provide custom prompt to the chain. This will override the existing default prompt used. See <a target="_blank" href="https://python.langchain.com/docs/integrations/tools/sqlite#customize-prompt">guide</a>',
warning:
'Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above',
rows: 4,
placeholder: defaultPrompt,
additionalParams: true,
optional: true
}
]
}
Expand All @@ -72,17 +105,19 @@ class SqlDatabaseChain_Chains implements INode {
const databaseType = nodeData.inputs?.database as DatabaseType
const model = nodeData.inputs?.model as BaseLanguageModel
const url = nodeData.inputs?.url
const customPrompt = nodeData.inputs?.customPrompt as string

const chain = await getSQLDBChain(databaseType, url, model)
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
return chain
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
const databaseType = nodeData.inputs?.database as DatabaseType
const model = nodeData.inputs?.model as BaseLanguageModel
const url = nodeData.inputs?.url
const customPrompt = nodeData.inputs?.customPrompt as string

const chain = await getSQLDBChain(databaseType, url, model)
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
const loggerHandler = new ConsoleCallbackHandler(options.logger)

if (options.socketIO && options.socketIOClientId) {
Expand All @@ -96,7 +131,7 @@ class SqlDatabaseChain_Chains implements INode {
}
}

const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel) => {
const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel, customPrompt?: string) => {
const datasource = new DataSource(
databaseType === 'sqlite'
? {
Expand All @@ -119,6 +154,14 @@ const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseL
verbose: process.env.DEBUG === 'true' ? true : false
}

if (customPrompt) {
const options: PromptTemplateInput = {
template: customPrompt,
inputVariables: getInputVariables(customPrompt)
}
obj.prompt = new PromptTemplate(options)
}

const chain = new SqlDatabaseChain(obj)
return chain
}
Expand Down
27 changes: 20 additions & 7 deletions packages/server/marketplaces/chatflows/SQL DB Chain.json
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,17 @@
},
{
"width": 300,
"height": 423,
"height": 475,
"id": "sqlDatabaseChain_0",
"position": {
"x": 1229.0092429246013,
"y": 231.59431102290245
"x": 1206.5244299447634,
"y": 201.04431101230608
},
"type": "customNode",
"data": {
"id": "sqlDatabaseChain_0",
"label": "Sql Database Chain",
"version": 1,
"version": 2,
"name": "sqlDatabaseChain",
"type": "SqlDatabaseChain",
"baseClasses": ["SqlDatabaseChain", "BaseChain", "Runnable"],
Expand Down Expand Up @@ -205,6 +205,18 @@
"type": "string",
"placeholder": "1270.0.0.1:5432/chinook",
"id": "sqlDatabaseChain_0-input-url-string"
},
{
"label": "Custom Prompt",
"name": "customPrompt",
"type": "string",
"description": "You can provide custom prompt to the chain. This will override the existing default prompt used. See <a target=\"_blank\" href=\"https://python.langchain.com/docs/integrations/tools/sqlite#customize-prompt\">guide</a>",
"warning": "Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above",
"rows": 4,
"placeholder": "Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\nUse the following format:\n\nQuestion: \"Question here\"\nSQLQuery: \"SQL Query to run\"\nSQLResult: \"Result of the SQLQuery\"\nAnswer: \"Final answer here\"\n\nOnly use the tables listed below.\n\n{table_info}\n\nQuestion: {input}",
"additionalParams": true,
"optional": true,
"id": "sqlDatabaseChain_0-input-customPrompt-string"
}
],
"inputAnchors": [
Expand All @@ -218,7 +230,8 @@
"inputs": {
"model": "{{chatOpenAI_0.data.instance}}",
"database": "sqlite",
"url": ""
"url": "",
"customPrompt": ""
},
"outputAnchors": [
{
Expand All @@ -233,8 +246,8 @@
},
"selected": false,
"positionAbsolute": {
"x": 1229.0092429246013,
"y": 231.59431102290245
"x": 1206.5244299447634,
"y": 201.04431101230608
},
"dragging": false
}
Expand Down