From b7dd682c2169d747aced4af813620636c6ba1cb3 Mon Sep 17 00:00:00 2001
From: Stephen Gutekanst
Date: Tue, 20 Feb 2024 00:23:18 -0700
Subject: [PATCH] add an OpenAI-compatible provider as a generic Enterprise LLM adapter

Increasingly, LLM software is standardizing around OpenAI-compatible endpoints. Some examples:

* [OpenLLM](https://github.com/bentoml/OpenLLM) (commonly used to self-host/deploy various LLMs in enterprises)
* [Huggingface TGI](https://github.com/huggingface/text-generation-inference/issues/735) (and, by extension, [AWS SageMaker](https://aws.amazon.com/blogs/machine-learning/announcing-the-launch-of-new-hugging-face-llm-inference-containers-on-amazon-sagemaker/))
* [Ollama](https://github.com/ollama/ollama) (commonly used for running LLMs locally, useful for local testing)

All of these projects either have OpenAI-compatible API endpoints already, or are actively building out support for them.

On strat, we regularly work with enterprise customers that self-host their own specific-model LLM via one of these methods and wish for Cody to consume an OpenAI endpoint, with the understanding that a specific model is on the other side and that Cody should optimize for / target that specific model.

Since Cody needs to tailor its behavior to a specific model (prompt generation, stop sequences, context limits, timeouts, etc.) and handle other provider-specific nuances, it is not sufficient to assume that a customer-provided OpenAI-compatible endpoint is in fact 1:1 compatible with e.g. GPT-3.5 or GPT-4. We need to be able to configure/tune many of these aspects for the specific provider/model, even though it presents as an OpenAI endpoint.

In response to these needs, I am working on adding an 'OpenAI-compatible' provider proper: the ability for a Sourcegraph enterprise instance to advertise that, although it is connected to an OpenAI-compatible endpoint, there is in fact a specific model on the other side (starting with Starchat and Starcoder), and that Cody should target that configuration.

The _first step_ of this work is this change. After this change, an existing (current-version) Sourcegraph enterprise instance can configure an OpenAI endpoint for completions via the site config, such as:

```
"cody.enabled": true,
"completions": {
  "provider": "openai",
  "accessToken": "asdf",
  "endpoint": "http://openllm.foobar.com:3000",
  "completionModel": "gpt-4",
  "chatModel": "gpt-4",
  "fastChatModel": "gpt-4",
},
```

The `gpt-4` model parameters will be sent to the specified OpenAI-compatible endpoint, but are otherwise unused today.

Users may then specify in their VS Code configuration that Cody should treat the LLM on the other side as if it were e.g. Starchat:

```
"cody.autocomplete.advanced.provider": "experimental-openaicompatible",
"cody.autocomplete.advanced.model": "starchat-16b-beta",
"cody.autocomplete.advanced.timeout.multiline": 10000,
"cody.autocomplete.advanced.timeout.singleline": 10000,
```

In the future, we will make it possible to configure the above options via the Sourcegraph site configuration instead of each user needing to configure them in their VS Code settings explicitly.

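For reference, the request Cody ultimately issues against such an endpoint is a plain OpenAI-style chat completion, with the model-specific details (stop sequences, temperature, output limits) chosen on Cody's side. The sketch below is illustrative only: the endpoint URL, access token, and the Starchat `<|end|>` stop sequence are assumptions for the example, not values shipped in this change.

```
// Sketch: what a request to an OpenAI-compatible server (OpenLLM, TGI, Ollama, ...)
// roughly looks like. URL, token, and stop sequence below are illustrative assumptions.
async function completeViaOpenAICompatible(prompt: string): Promise<string> {
    const response = await fetch('http://openllm.foobar.com:3000/v1/chat/completions', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            Authorization: 'Bearer asdf',
        },
        body: JSON.stringify({
            // The server runs whatever model it was deployed with; Cody still has to
            // know it is Starchat in order to pick the prompt format and stop tokens.
            model: 'gpt-4',
            messages: [{ role: 'user', content: prompt }],
            stop: ['<|end|>'],
            temperature: 0.2,
            max_tokens: 256,
        }),
    })
    const data = await response.json()
    return data.choices[0].message.content
}
```

The provider added in this patch is what makes those model-specific choices (prompt shape, stop sequences, context limits) on Cody's side.
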
Signed-off-by: Stephen Gutekanst
---
 lib/shared/src/chat/chat.ts                   |   2 +-
 lib/shared/src/configuration.ts               |   1 +
 .../src/sourcegraph-api/completions/client.ts |   6 +-
 vscode/package.json                           |   7 +-
 .../providers/create-provider.test.ts         |  24 +
 .../completions/providers/create-provider.ts  |  18 +
 .../completions/providers/openaicompatible.ts | 420 ++++++++++++++++++
 vscode/src/configuration.ts                   |   5 +-
 8 files changed, 477 insertions(+), 6 deletions(-)
 create mode 100644 vscode/src/completions/providers/openaicompatible.ts

diff --git a/lib/shared/src/chat/chat.ts b/lib/shared/src/chat/chat.ts
index 2532ad96fb1c..c0b6594f83f6 100644
--- a/lib/shared/src/chat/chat.ts
+++ b/lib/shared/src/chat/chat.ts
@@ -29,7 +29,7 @@ export class ChatClient {
             // HACK: The fireworks chat inference endpoints requires the last message to be from a
             // human. This will be the case in most of the prompts but if for some reason we have an
             // assistant at the end, we slice the last message for now.
-            params?.model?.startsWith('fireworks/')
+            params?.model?.startsWith('fireworks/') || params?.model?.startsWith('openaicompatible/')
                 ? isLastMessageFromHuman
                     ? messages
                     : messages.slice(0, -1)
diff --git a/lib/shared/src/configuration.ts b/lib/shared/src/configuration.ts
index ac3890b3b9e8..32a4fbf62452 100644
--- a/lib/shared/src/configuration.ts
+++ b/lib/shared/src/configuration.ts
@@ -36,6 +36,7 @@ export interface Configuration {
         | 'anthropic'
         | 'fireworks'
         | 'unstable-openai'
+        | 'experimental-openaicompatible'
         | 'experimental-ollama'
         | null
     autocompleteAdvancedModel: string | null
diff --git a/lib/shared/src/sourcegraph-api/completions/client.ts b/lib/shared/src/sourcegraph-api/completions/client.ts
index 42b0e4af1e0e..84478f0fdef7 100644
--- a/lib/shared/src/sourcegraph-api/completions/client.ts
+++ b/lib/shared/src/sourcegraph-api/completions/client.ts
@@ -92,8 +92,12 @@ export abstract class SourcegraphCompletionsClient {
         params: CompletionParameters,
         signal?: AbortSignal
     ): AsyncGenerator<CompletionGeneratorValue> {
-        // This is a technique to convert a function that takes callbacks to an async generator.
+        // Provide default stop sequence for starchat models.
+        if (!params.stopSequences && params?.model?.startsWith('openaicompatible/starchat')) {
+            params.stopSequences = ['<|end|>']
+        }
+
+        // This is a technique to convert a function that takes callbacks to an async generator.
         const values: Promise<CompletionGeneratorValue>[] = []
         let resolve: ((value: CompletionGeneratorValue) => void) | undefined
         values.push(
diff --git a/vscode/package.json b/vscode/package.json
index eb12ccaf7bcd..273ee2690380 100644
--- a/vscode/package.json
+++ b/vscode/package.json
@@ -902,7 +902,7 @@
         "cody.autocomplete.advanced.provider": {
           "type": "string",
           "default": null,
-          "enum": [null, "anthropic", "fireworks", "unstable-openai", "experimental-ollama"],
+          "enum": [null, "anthropic", "fireworks", "unstable-openai", "experimental-openaicompatible", "experimental-ollama"],
           "markdownDescription": "The provider used for code autocomplete. Most providers other than `anthropic` require the `cody.autocomplete.advanced.serverEndpoint` and `cody.autocomplete.advanced.accessToken` settings to also be set. Check the Cody output channel for error messages if autocomplete is not working as expected."
         },
         "cody.autocomplete.advanced.serverEndpoint": {
@@ -925,9 +925,10 @@
             "llama-code-7b",
             "llama-code-13b",
             "llama-code-13b-instruct",
-            "mistral-7b-instruct-4k"
+            "mistral-7b-instruct-4k",
+            "starchat-16b-beta"
           ],
-          "markdownDescription": "Overwrite the model used for code autocompletion inference. This is only supported with the `fireworks` provider"
+          "markdownDescription": "Overwrite the model used for code autocompletion inference. This is only supported with the `fireworks` and `experimental-openaicompatible` providers."
         },
         "cody.autocomplete.completeSuggestWidgetSelection": {
           "type": "boolean",
diff --git a/vscode/src/completions/providers/create-provider.test.ts b/vscode/src/completions/providers/create-provider.test.ts
index 8ccdbedb16dc..6ee87379a3ed 100644
--- a/vscode/src/completions/providers/create-provider.test.ts
+++ b/vscode/src/completions/providers/create-provider.test.ts
@@ -86,6 +86,30 @@ describe('createProviderConfig', () => {
         expect(provider?.model).toBe('starcoder-hybrid')
     })
 
+    // TODO: test 'openaicompatible'
+    // it('returns "fireworks" provider config and corresponding model if specified', async () => {
+    //     const provider = await createProviderConfig(
+    //         getVSCodeConfigurationWithAccessToken({
+    //             autocompleteAdvancedProvider: 'fireworks',
+    //             autocompleteAdvancedModel: 'starcoder-7b',
+    //         }),
+    //         dummyCodeCompletionsClient,
+    //         dummyAuthStatus
+    //     )
+    //     expect(provider?.identifier).toBe('fireworks')
+    //     expect(provider?.model).toBe('starcoder-7b')
+    // })
+
+    // it('returns "fireworks" provider config if specified in settings and default model', async () => {
+    //     const provider = await createProviderConfig(
+    //         getVSCodeConfigurationWithAccessToken({ autocompleteAdvancedProvider: 'fireworks' }),
+    //         dummyCodeCompletionsClient,
+    //         dummyAuthStatus
+    //     )
+    //     expect(provider?.identifier).toBe('fireworks')
+    //     expect(provider?.model).toBe('starcoder-hybrid')
+    // })
+
     it('returns "openai" provider config if specified in VSCode settings; model is ignored', async () => {
         const provider = await createProviderConfig(
             getVSCodeConfigurationWithAccessToken({
diff --git a/vscode/src/completions/providers/create-provider.ts b/vscode/src/completions/providers/create-provider.ts
index 80f10a0d03bb..9b3035110529 100644
--- a/vscode/src/completions/providers/create-provider.ts
+++ b/vscode/src/completions/providers/create-provider.ts
@@ -12,6 +12,7 @@ import {
     createProviderConfig as createFireworksProviderConfig,
     type FireworksOptions,
 } from './fireworks'
+import { createProviderConfig as createOpenAICompatibleProviderConfig } from './openaicompatible'
 import type { ProviderConfig } from './provider'
 import { createProviderConfig as createExperimentalOllamaProviderConfig } from './experimental-ollama'
 import { createProviderConfig as createUnstableOpenAIProviderConfig } from './unstable-openai'
@@ -49,6 +50,15 @@ export async function createProviderConfig(
             case 'anthropic': {
                 return createAnthropicProviderConfig({ client })
             }
+            case 'experimental-openaicompatible': {
+                return createOpenAICompatibleProviderConfig({
+                    client,
+                    model: config.autocompleteAdvancedModel ?? model ?? null,
+                    timeouts: config.autocompleteTimeouts,
+                    authStatus,
+                    config,
+                })
+            }
             case 'experimental-ollama':
             case 'unstable-ollama': {
                 return createExperimentalOllamaProviderConfig(
@@ -99,6 +109,14 @@
                 authStatus,
                 config,
             })
+        case 'experimental-openaicompatible':
+            return createOpenAICompatibleProviderConfig({
+                client,
+                timeouts: config.autocompleteTimeouts,
+                model: model ?? null,
+                authStatus,
+                config,
+            })
         case 'aws-bedrock':
         case 'anthropic':
             return createAnthropicProviderConfig({
diff --git a/vscode/src/completions/providers/openaicompatible.ts b/vscode/src/completions/providers/openaicompatible.ts
new file mode 100644
index 000000000000..37197b609a77
--- /dev/null
+++ b/vscode/src/completions/providers/openaicompatible.ts
@@ -0,0 +1,420 @@
+import * as vscode from 'vscode'
+
+import {
+    displayPath,
+    tokensToChars,
+    type AutocompleteTimeouts,
+    type CodeCompletionsClient,
+    type CompletionResponseGenerator,
+    type CodeCompletionsParams,
+    type ConfigurationWithAccessToken,
+} from '@sourcegraph/cody-shared'
+
+import { getLanguageConfig } from '../../tree-sitter/language'
+import { CLOSING_CODE_TAG, getHeadAndTail, OPENING_CODE_TAG } from '../text-processing'
+import type { ContextSnippet } from '../types'
+import { forkSignal, generatorWithTimeout, zipGenerators } from '../utils'
+
+import type { FetchCompletionResult } from './fetch-and-process-completions'
+import {
+    getCompletionParamsAndFetchImpl,
+    getLineNumberDependentCompletionParams,
+} from './get-completion-params'
+import {
+    Provider,
+    standardContextSizeHints,
+    type CompletionProviderTracer,
+    type ProviderConfig,
+    type ProviderOptions,
+} from './provider'
+import type { AuthStatus } from '../../chat/protocol'
+
+export interface OpenAICompatibleOptions {
+    model: OpenAICompatibleModel
+    maxContextTokens?: number
+    client: CodeCompletionsClient
+    timeouts: AutocompleteTimeouts
+    config: Pick<ConfigurationWithAccessToken, 'accessToken'>
+    authStatus: Pick<AuthStatus, 'userCanUpgrade' | 'isDotCom' | 'endpoint' | 'configOverwrites'>
+}
+
+const PROVIDER_IDENTIFIER = 'openaicompatible'
+
+const EOT_STARCHAT = '<|end|>'
+const EOT_STARCODER = '<|endoftext|>'
+const EOT_LLAMA_CODE = ' <EOT>'
+
+// Model identifiers (we are the source/definition for these in the case of the openaicompatible provider.)
+const MODEL_MAP = {
+    starchat: 'openaicompatible/starchat-16b-beta',
+    'starchat-16b-beta': 'openaicompatible/starchat-16b-beta',
+
+    starcoder: 'openaicompatible/starcoder',
+    'starcoder-16b': 'openaicompatible/starcoder-16b',
+    'starcoder-7b': 'openaicompatible/starcoder-7b',
+    'llama-code-7b': 'openaicompatible/llama-code-7b',
+    'llama-code-13b': 'openaicompatible/llama-code-13b',
+    'llama-code-13b-instruct': 'openaicompatible/llama-code-13b-instruct',
+    'mistral-7b-instruct-4k': 'openaicompatible/mistral-7b-instruct-4k',
+}
+
+type OpenAICompatibleModel =
+    | keyof typeof MODEL_MAP
+    // `starcoder-hybrid` uses the 16b model for multiline requests and the 7b model for single line
+    | 'starcoder-hybrid'
+
+function getMaxContextTokens(model: OpenAICompatibleModel): number {
+    switch (model) {
+        case 'starchat':
+        case 'starchat-16b-beta':
+        case 'starcoder':
+        case 'starcoder-hybrid':
+        case 'starcoder-16b':
+        case 'starcoder-7b': {
+            // StarCoder and StarChat support up to 8k tokens, we limit to ~6k so we do not hit token limits.
+            return 8192 - 2048
+        }
+        case 'llama-code-7b':
+        case 'llama-code-13b':
+        case 'llama-code-13b-instruct':
+            // Llama Code was trained on 16k context windows, we're constraining it here to better
+            // compare the results
+            return 16384 - 2048
+        case 'mistral-7b-instruct-4k':
+            return 4096 - 2048
+        default:
+            return 1200
+    }
+}
+
+const MAX_RESPONSE_TOKENS = 256
+
+const lineNumberDependentCompletionParams = getLineNumberDependentCompletionParams({
+    singlelineStopSequences: ['\n'],
+    multilineStopSequences: ['\n\n', '\n\r\n'],
+})
+
+class OpenAICompatibleProvider extends Provider {
+    private model: OpenAICompatibleModel
+    private promptChars: number
+    private client: CodeCompletionsClient
+    private timeouts?: AutocompleteTimeouts
+
+    constructor(
+        options: ProviderOptions,
+        {
+            model,
+            maxContextTokens,
+            client,
+            timeouts,
+        }: Required<OpenAICompatibleOptions>
+    ) {
+        super(options)
+        this.timeouts = timeouts
+        this.model = model
+        this.promptChars = tokensToChars(maxContextTokens - MAX_RESPONSE_TOKENS)
+        this.client = client
+    }
+
+    private createPrompt(snippets: ContextSnippet[]): string {
+        const { prefix, suffix } = this.options.docContext
+
+        const intro: string[] = []
+        let prompt = ''
+
+        const languageConfig = getLanguageConfig(this.options.document.languageId)
+
+        // In StarCoder we have a special token to announce the path of the file
+        if (!isStarCoderFamily(this.model)) {
+            intro.push(`Path: ${this.options.document.fileName}`)
+        }
+
+        for (let snippetsToInclude = 0; snippetsToInclude < snippets.length + 1; snippetsToInclude++) {
+            if (snippetsToInclude > 0) {
+                const snippet = snippets[snippetsToInclude - 1]
+                if ('symbol' in snippet && snippet.symbol !== '') {
+                    intro.push(
+                        `Additional documentation for \`${snippet.symbol}\`:\n\n${snippet.content}`
+                    )
+                } else {
+                    intro.push(
+                        `Here is a reference snippet of code from ${displayPath(snippet.uri)}:\n\n${
+                            snippet.content
+                        }`
+                    )
+                }
+            }
+
+            const introString = `${intro
+                .join('\n\n')
+                .split('\n')
+                .map(line => (languageConfig ? languageConfig.commentStart + line : '// '))
+                .join('\n')}\n`
+
+            const suffixAfterFirstNewline = getSuffixAfterFirstNewline(suffix)
+
+            const nextPrompt = this.createInfillingPrompt(
+                vscode.workspace.asRelativePath(this.options.document.fileName),
+                introString,
+                prefix,
+                suffixAfterFirstNewline
+            )
+
+            if (nextPrompt.length >= this.promptChars) {
+                return prompt
+            }
+
+            prompt = nextPrompt
+        }
+
+        return prompt
+    }
+
+    public generateCompletions(
+        abortSignal: AbortSignal,
+        snippets: ContextSnippet[],
+        tracer?: CompletionProviderTracer
+    ): AsyncGenerator<FetchCompletionResult[]> {
+        const { partialRequestParams, fetchAndProcessCompletionsImpl } = getCompletionParamsAndFetchImpl(
+            {
+                providerOptions: this.options,
+                timeouts: this.timeouts,
+                lineNumberDependentCompletionParams,
+            }
+        )
+
+        // starchat: Only use infill if the suffix is not empty
+        const useInfill = this.options.docContext.suffix.trim().length > 0
+        const promptProps: Prompt = {
+            snippets: [],
+            uri: this.options.document.uri,
+            prefix: this.options.docContext.prefix,
+            suffix: this.options.docContext.suffix,
+            languageId: this.options.document.languageId,
+        }
+
+        const prompt = this.model.startsWith('starchat')
+            ? promptString(promptProps, useInfill, this.model)
+            : this.createPrompt(snippets)
+
+        const { multiline } = this.options
+        const requestParams: CodeCompletionsParams = {
+            ...partialRequestParams,
+            messages: [{ speaker: 'human', text: prompt }],
+            temperature: 0.2,
+            topK: 0,
+            model:
+                this.model === 'starcoder-hybrid'
+                    ? MODEL_MAP[multiline ? 'starcoder-16b' : 'starcoder-7b']
+                    : this.model.startsWith('starchat')
+                      ? '' // starchat is not a supported backend model yet, use the default server-chosen model.
+                      : MODEL_MAP[this.model],
+        }
+
+        tracer?.params(requestParams)
+
+        const completionsGenerators = Array.from({ length: this.options.n }).map(() => {
+            const abortController = forkSignal(abortSignal)
+
+            const completionResponseGenerator = generatorWithTimeout(
+                this.createDefaultClient(requestParams, abortController),
+                requestParams.timeoutMs,
+                abortController
+            )
+
+            return fetchAndProcessCompletionsImpl({
+                completionResponseGenerator,
+                abortController,
+                providerSpecificPostProcess: this.postProcess,
+                providerOptions: this.options,
+            })
+        })
+
+        /**
+         * This implementation waits for all generators to yield values
+         * before passing them to the consumer (request-manager). While this may appear
+         * as a performance bottleneck, it's necessary for the current design.
+         *
+         * The consumer operates on promises, allowing only a single resolve call
+         * from `requestManager.request`. Therefore, we must wait for the initial
+         * batch of completions before returning them collectively, ensuring all
+         * are included as suggested completions.
+         *
+         * To circumvent this performance issue, a method for adding completions to
+         * the existing suggestion list is needed. Presently, this feature is not
+         * available, and the switch to async generators maintains the same behavior
+         * as with promises.
+         */
+        return zipGenerators(completionsGenerators)
+    }
+
+    private createInfillingPrompt(
+        filename: string,
+        intro: string,
+        prefix: string,
+        suffix: string
+    ): string {
+        if (isStarCoderFamily(this.model) || isStarChatFamily(this.model)) {
+            // c.f. https://huggingface.co/bigcode/starcoder#fill-in-the-middle
+            // c.f. https://arxiv.org/pdf/2305.06161.pdf
+            return `<filename>${filename}${intro}<fim_prefix>${prefix}<fim_suffix>${suffix}<fim_middle>`
+        }
+        if (isLlamaCode(this.model)) {
+            // c.f. https://github.com/facebookresearch/codellama/blob/main/llama/generation.py#L402
+            return `<PRE> ${intro}${prefix} <SUF>${suffix} <MID>`
+        }
+        if (this.model === 'mistral-7b-instruct-4k') {
+            // This part is copied from the anthropic prompt but fitted into the Mistral instruction format
+            const relativeFilePath = vscode.workspace.asRelativePath(this.options.document.fileName)
+            const { head, tail } = getHeadAndTail(this.options.docContext.prefix)
+            const infillSuffix = this.options.docContext.suffix
+            const infillBlock = tail.trimmed.endsWith('{\n') ? tail.trimmed.trimEnd() : tail.trimmed
+            const infillPrefix = head.raw
+            return `[INST] Below is the code from file path ${relativeFilePath}. Review the code outside the XML tags to detect the functionality, formats, style, patterns, and logics in use. Then, use what you detect and reuse methods/libraries to complete and enclose completed code only inside XML tags precisely without duplicating existing implementations. Here is the code:
+\`\`\`
+${intro}${infillPrefix}${OPENING_CODE_TAG}${CLOSING_CODE_TAG}${infillSuffix}
+\`\`\`[/INST]
+ ${OPENING_CODE_TAG}${infillBlock}`
+        }
+
+        console.error('Could not generate infilling prompt for', this.model)
+        return `${intro}${prefix}`
+    }
+
+    private postProcess = (content: string): string => {
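+        // Strip the model-specific end-of-text token so it never ends up in the inserted completion.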
+        if (isStarCoderFamily(this.model)) {
+            return content.replace(EOT_STARCODER, '')
+        }
+        if (isStarChatFamily(this.model)) {
+            return content.replace(EOT_STARCHAT, '')
+        }
+        if (isLlamaCode(this.model)) {
+            return content.replace(EOT_LLAMA_CODE, '')
+        }
+        return content
+    }
+
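+    // Issues the completion request through the shared Sourcegraph code-completions client.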
+    private createDefaultClient(
+        requestParams: CodeCompletionsParams,
+        abortController: AbortController
+    ): CompletionResponseGenerator {
+        return this.client.complete(requestParams, abortController)
+    }
+}
+
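+// Resolves the configured model name (defaulting to 'starcoder-hybrid' when unset) and builds the provider config.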
+export function createProviderConfig({
+    model,
+    timeouts,
+    ...otherOptions
+}: Omit<OpenAICompatibleOptions, 'model' | 'maxContextTokens'> & {
+    model: string | null
+}): ProviderConfig {
+    const resolvedModel =
+        model === null || model === ''
+            ? 'starcoder-hybrid'
+            : model === 'starcoder-hybrid'
+              ? 'starcoder-hybrid'
+              : Object.prototype.hasOwnProperty.call(MODEL_MAP, model)
+                  ? (model as keyof typeof MODEL_MAP)
+                  : null
+
+    if (resolvedModel === null) {
+        throw new Error(`Unknown model: \`${model}\``)
+    }
+
+    const maxContextTokens = getMaxContextTokens(resolvedModel)
+
+    return {
+        create(options: ProviderOptions) {
+            return new OpenAICompatibleProvider(
+                {
+                    ...options,
+                    id: PROVIDER_IDENTIFIER,
+                },
+                {
+                    model: resolvedModel,
+                    maxContextTokens,
+                    timeouts,
+                    ...otherOptions,
+                }
+            )
+        },
+        contextSizeHints: standardContextSizeHints(maxContextTokens),
+        identifier: PROVIDER_IDENTIFIER,
+        model: resolvedModel,
+    }
+}
+
+// We want to remove the same line suffix from a completion request since both StarCoder and Llama
+// code can't handle this correctly.
+function getSuffixAfterFirstNewline(suffix: string): string {
+    const firstNlInSuffix = suffix.indexOf('\n')
+
+    // When there is no next line, the suffix should be empty
+    if (firstNlInSuffix === -1) {
+        return ''
+    }
+
+    return suffix.slice(firstNlInSuffix)
+}
+
+function isStarChatFamily(model: string): boolean {
+    return model.startsWith('starchat')
+}
+
+function isStarCoderFamily(model: string): boolean {
+    return model.startsWith('starcoder')
+}
+
+function isLlamaCode(model: string): boolean {
+    return model.startsWith('llama-code')
+}
+
+interface Prompt {
+    snippets: { uri: vscode.Uri; content: string }[]
+
+    uri: vscode.Uri
+    prefix: string
+    suffix: string
+
+    languageId: string
+}
+
+function fileNameLine(uri: vscode.Uri, commentStart: string): string {
+    return `${commentStart} Path: ${displayPath(uri)}\n`
+}
+
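+// Builds the plain-text prompt used for chat-style models: context snippets rendered as comments,
+// followed by the current file's path comment and prefix (plus a Code Llama infill variant).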
+function promptString(prompt: Prompt, infill: boolean, model: string): string {
+    const config = getLanguageConfig(prompt.languageId)
+    const commentStart = config?.commentStart || '//'
+
+    const context = prompt.snippets
+        .map(
+            ({ uri, content }) =>
+                fileNameLine(uri, commentStart) +
+                content
+                    .split('\n')
+                    .map(line => `${commentStart} ${line}`)
+                    .join('\n')
+        )
+        .join('\n\n')
+
+    const currentFileNameComment = fileNameLine(prompt.uri, commentStart)
+
+    if (model.startsWith('codellama:') && infill) {
+        const infillPrefix = context + currentFileNameComment + prompt.prefix
+
+        /**
+         * The infill prompt for Code Llama.
+         * Source: https://github.com/facebookresearch/codellama/blob/e66609cfbd73503ef25e597fd82c59084836155d/llama/generation.py#L418
+         *
+         * Why are there spaces left and right?
+         * > For instance, the model expects this format: `<PRE> {pre} <SUF>{suf} <MID>`.
+         * But you won’t get infilling if the last space isn’t added such as in `<PRE> {pre} <SUF>{suf}<MID>`
+         *
+         * Source: https://blog.fireworks.ai/simplifying-code-infilling-with-code-llama-and-fireworks-ai-92c9bb06e29c
+         */
+        return `<PRE> ${infillPrefix} <SUF>${prompt.suffix} <MID>`
+    }
+
+    return context + currentFileNameComment + prompt.prefix
+}
diff --git a/vscode/src/configuration.ts b/vscode/src/configuration.ts
index a4dc0727d9b2..7ffe98f312b4 100644
--- a/vscode/src/configuration.ts
+++ b/vscode/src/configuration.ts
@@ -51,7 +51,10 @@ export function getConfiguration(
     }
 
     let autocompleteAdvancedProvider = config.get<
-        Configuration['autocompleteAdvancedProvider'] | 'unstable-ollama' | 'unstable-fireworks'
+        | Configuration['autocompleteAdvancedProvider']
+        | 'unstable-ollama'
+        | 'unstable-fireworks'
+        | 'experimental-openaicompatible'
     >(CONFIG_KEY.autocompleteAdvancedProvider, null)
 
     // Handle deprecated provider identifiers