Skip to content

Commit

Permalink
Merge pull request #204 from Portkey-AI/feat/image-models
Browse files Browse the repository at this point in the history
Feat/image models
  • Loading branch information
VisargD authored Feb 13, 2024
2 parents 82ab583 + 5a7c1f3 commit 4d3425f
Show file tree
Hide file tree
Showing 14 changed files with 250 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export const GOOGLE: string = "google";
export const PERPLEXITY_AI: string = "perplexity-ai";
export const MISTRAL_AI: string = "mistral-ai";
export const DEEPINFRA: string = "deepinfra";
export const STABILITY_AI: string = "stability-ai";
export const NOMIC: string = "nomic";

export const providersWithStreamingSupport = [OPEN_AI, AZURE_OPEN_AI, ANTHROPIC, COHERE];
Expand Down
10 changes: 9 additions & 1 deletion src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Context } from "hono";
import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES } from "../globals";
import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals";
import Providers from "../providers";
import { ProviderAPIConfig, endpointStrings } from "../providers/types";
import transformToProviderRequest from "../services/transformToProviderRequest";
Expand Down Expand Up @@ -188,6 +188,10 @@ export async function tryPostProxy(c: Context, providerOption:Options, inputPara
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params.model, params.stream);
} else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, params.model, url);
} else {
// Construct the base object for the request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method);
Expand Down Expand Up @@ -332,6 +336,10 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream);
} else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
baseUrl = apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, params.model);
} else {
// Construct the base object for the POST request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
Expand Down
43 changes: 43 additions & 0 deletions src/handlers/imageGenerationsHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { constructConfigFromRequestHeaders, tryTargetsRecursively } from "./handlerUtils";
import { Context } from "hono";

/**
 * Handles the '/images/generations' API request by selecting the appropriate provider(s) and making the request to them.
 *
 * @param {Context} c - The Cloudflare Worker context.
 * @returns {Promise<Response>} - The response from the provider, or a generic
 *   500 JSON response if the handler itself fails.
 * @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails.
 */
export async function imageGenerationsHandler(c: Context): Promise<Response> {
  try {
    const request = await c.req.json();
    const requestHeaders = Object.fromEntries(c.req.raw.headers);
    // Build the gateway config (provider, fallbacks, retries, ...) from the
    // x-portkey-* request headers.
    const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);

    const tryTargetsResponse = await tryTargetsRecursively(
      c,
      camelCaseConfig,
      request,
      requestHeaders,
      "imageGenerate",
      "POST",
      "config"
    );

    return tryTargetsResponse;
  } catch (err: any) {
    // Log the real failure server-side; clients only see a generic 500 payload.
    console.error("imageGenerate error", err.message);
    return new Response(
      JSON.stringify({
        status: "failure",
        message: "Something went wrong",
      }), {
        status: 500,
        headers: {
          "content-type": "application/json"
        }
      }
    );
  }
}
7 changes: 7 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { embeddingsHandler } from "./handlers/embeddingsHandler";
import { requestValidator } from "./middlewares/requestValidator";
import { compress } from "hono/compress";
import { getRuntimeKey } from "hono/adapter";
import { imageGenerationsHandler } from "./handlers/imageGenerationsHandler";

// Create a new Hono server instance
const app = new Hono();
Expand Down Expand Up @@ -102,6 +103,12 @@ app.post("/v1/completions", requestValidator, completionsHandler);
*/
app.post("/v1/embeddings", requestValidator, embeddingsHandler);

/**
* POST route for '/v1/images/generations'.
* Handles requests by passing them to the imageGenerations handler.
*/
app.post("/v1/images/generations", requestValidator, imageGenerationsHandler);

/**
* POST route for '/v1/prompts/:id/completions'.
* Handles portkey prompt completions route
Expand Down
3 changes: 2 additions & 1 deletion src/middlewares/requestValidator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
POWERED_BY,
TOGETHER_AI,
DEEPINFRA,
STABILITY_AI,
NOMIC,
} from "../../globals";
import { configSchema } from "./schema/config";
Expand Down Expand Up @@ -62,7 +63,7 @@ export const requestValidator = (c: Context, next: any) => {
}
if (
requestHeaders[`x-${POWERED_BY}-provider`] &&
![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC].includes(
![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC, STABILITY_AI].includes(
requestHeaders[`x-${POWERED_BY}-provider`]
)
) {
Expand Down
4 changes: 3 additions & 1 deletion src/middlewares/requestValidator/schema/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
TOGETHER_AI,
DEEPINFRA,
NOMIC,
STABILITY_AI,
} from "../../../globals";

export const configSchema: any = z
Expand Down Expand Up @@ -47,7 +48,8 @@ export const configSchema: any = z
PERPLEXITY_AI,
MISTRAL_AI,
DEEPINFRA,
NOMIC
NOMIC,
STABILITY_AI
].includes(value),
{
message:
Expand Down
2 changes: 2 additions & 0 deletions src/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import OpenAIConfig from "./openai";
import PalmAIConfig from "./palm";
import PerplexityAIConfig from "./perplexity-ai";
import TogetherAIConfig from "./together-ai";
import StabilityAIConfig from "./stability-ai";
import { ProviderConfigs } from "./types";

const Providers: { [key: string]: ProviderConfigs } = {
Expand All @@ -24,6 +25,7 @@ const Providers: { [key: string]: ProviderConfigs } = {
'perplexity-ai': PerplexityAIConfig,
'mistral-ai': MistralAIConfig,
'deepinfra': DeepInfraConfig,
'stability-ai': StabilityAIConfig,
nomic: NomicConfig
};

Expand Down
3 changes: 2 additions & 1 deletion src/providers/openai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ const OpenAIAPIConfig: ProviderAPIConfig = {
},
complete: "/completions",
chatComplete: "/chat/completions",
embed: "/embeddings"
embed: "/embeddings",
imageGenerate: "/images/generations"
};

export default OpenAIAPIConfig;
45 changes: 45 additions & 0 deletions src/providers/openai/imageGenerate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { ImageGenerateResponse, ProviderConfig } from "../types";

// Parameter mapping for OpenAI image generation requests. Each key is the
// incoming (OpenAI-style) parameter; `param` is the name sent on the wire.
// Values pass through unchanged — no transforms are needed for OpenAI itself.
export const OpenAIImageGenerateConfig: ProviderConfig = {
  prompt: {
    param: "prompt",
    required: true
  },
  model: {
    param: "model",
    required: true,
    default: "dall-e-2" // model used when the caller does not specify one
  },
  n: {
    param: "n",
    min: 1,
    max: 10 // number of images per request, bounded 1..10
  },
  quality: {
    param: "quality"
  },
  response_format: {
    param: "response_format"
  },
  size: {
    param: "size"
  },
  style: {
    param: "style"
  },
  user: {
    param: "user"
  }
}

/** A single generated image as returned by OpenAI. */
interface OpenAIImageObject {
  /** Base64-encoded image, present when response_format is "b64_json". */
  b64_json?: string;
  /** Hosted image URL, present when response_format is "url" (the default). */
  url?: string;
  /** The prompt actually used, if OpenAI revised the original one. */
  revised_prompt?: string;
}

/** OpenAI's image generation response narrows `data` to image objects. */
interface OpenAIImageGenerateResponse extends ImageGenerateResponse {
  data: OpenAIImageObject[];
}

/**
 * OpenAI responses already match the gateway's unified image schema,
 * so this transform is the identity function.
 */
export const OpenAIImageGenerateResponseTransform: (
  response: OpenAIImageGenerateResponse
) => ImageGenerateResponse = (response) => {
  return response;
};
5 changes: 4 additions & 1 deletion src/providers/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@ import { OpenAICompleteConfig, OpenAICompleteResponseTransform } from "./complet
import { OpenAIEmbedConfig, OpenAIEmbedResponseTransform } from "./embed";
import OpenAIAPIConfig from "./api";
import { OpenAIChatCompleteConfig, OpenAIChatCompleteResponseTransform } from "./chatComplete";
import { OpenAIImageGenerateConfig, OpenAIImageGenerateResponseTransform } from "./imageGenerate";

const OpenAIConfig: ProviderConfigs = {
complete: OpenAICompleteConfig,
embed: OpenAIEmbedConfig,
api: OpenAIAPIConfig,
chatComplete: OpenAIChatCompleteConfig,
imageGenerate: OpenAIImageGenerateConfig,
responseTransforms: {
complete: OpenAICompleteResponseTransform,
// 'stream-complete': OpenAICompleteResponseTransform,
chatComplete: OpenAIChatCompleteResponseTransform,
// 'stream-chatComplete': OpenAIChatCompleteResponseTransform,
embed: OpenAIEmbedResponseTransform
embed: OpenAIEmbedResponseTransform,
imageGenerate: OpenAIImageGenerateResponseTransform
}
};

Expand Down
22 changes: 22 additions & 0 deletions src/providers/stability-ai/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { ProviderAPIConfig } from "../types";

/**
 * Stability AI REST API settings: base URL, bearer-auth header construction,
 * and endpoint resolution. Stability addresses models ("engines") via the
 * URL path rather than a request-body field.
 */
const StabilityAIAPIConfig: ProviderAPIConfig = {
  baseURL: "https://api.stability.ai/v1",
  headers: (API_KEY: string) => {
    return { "Authorization": `Bearer ${API_KEY}` };
  },
  getEndpoint: (fn: string, ENGINE_ID: string, url?: string) => {
    // Proxy requests whose target path contains "text-to-image" are routed
    // exactly like first-class image generation calls.
    const proxiedImageCall =
      fn === "proxy" && !!url && url.indexOf("text-to-image") > -1;
    if (fn === "imageGenerate" || proxiedImageCall) {
      return `/generation/${ENGINE_ID}/text-to-image`;
    }
    // Any other function is unmapped; callers receive undefined.
  }
};

export default StabilityAIAPIConfig;
85 changes: 85 additions & 0 deletions src/providers/stability-ai/imageGenerate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import { STABILITY_AI } from "../../globals";
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from "../types";

/**
 * Maps the gateway's OpenAI-style image generation params onto Stability AI's
 * text-to-image request shape.
 */
export const StabilityAIImageGenerateConfig: ProviderConfig = {
  prompt: {
    param: "text_prompts",
    required: true,
    // Stability expects an array of weighted prompts; wrap the single
    // incoming prompt with a neutral weight of 1.
    transform: (params: any) => {
      const weightedPrompt = { text: params.prompt, weight: 1 };
      return [weightedPrompt];
    }
  },
  n: {
    param: "samples",
    min: 1,
    max: 10
  },
  // An OpenAI "WIDTHxHEIGHT" size string splits into two Stability params.
  size: [
    {
      param: "height",
      transform: (params: any) => {
        const dimensions = params.size.toLowerCase().split("x");
        return parseInt(dimensions[1]);
      },
      min: 320
    },
    {
      param: "width",
      transform: (params: any) => {
        const dimensions = params.size.toLowerCase().split("x");
        return parseInt(dimensions[0]);
      },
      min: 320
    }
  ],
  style: {
    param: "style_preset"
  }
}

/** Successful Stability AI image generation response. */
interface StabilityAIImageGenerateResponse extends ImageGenerateResponse {
  artifacts: ImageArtifact[];
}

/** Stability AI error payload returned on non-200 responses. */
interface StabilityAIImageGenerateErrorResponse {
  id: string;
  name: string;
  message: string;
}

/** One generated image. */
interface ImageArtifact {
  base64: string; // Image encoded in base64
  finishReason: 'CONTENT_FILTERED' | 'ERROR' | 'SUCCESS'; // Enum for finish reason
  seed: number; // The seed associated with this image
}

/**
 * Transforms a Stability AI text-to-image response into the gateway's
 * unified (OpenAI-style) image generation schema.
 *
 * Non-200 responses with a `message` are mapped to the unified error shape;
 * successful responses have their artifacts re-emitted as b64_json entries;
 * anything else is surfaced as an "invalid response" error.
 */
export const StabilityAIImageGenerateResponseTransform: (response: StabilityAIImageGenerateResponse | StabilityAIImageGenerateErrorResponse, responseStatus: number) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
  if (responseStatus !== 200 && 'message' in response) {
    return {
      error: {
        message: response.message,
        type: response.name,
        param: null,
        code: null,
      },
      provider: STABILITY_AI
    }
  }

  if ('artifacts' in response) {
    return {
      // Stability does not return a creation timestamp, so stamp one here
      // (epoch milliseconds, stringified to fit ImageGenerateResponse).
      created: `${new Date().getTime()}`,
      data: response.artifacts.map(art => ({b64_json: art.base64})),
      provider: STABILITY_AI
    };
  }

  // Neither a recognizable error nor a success payload.
  return {
    error: {
      message: `Invalid response received from ${STABILITY_AI}: ${JSON.stringify(response)}`,
      type: null,
      param: null,
      code: null
    },
    provider: STABILITY_AI
  } as ErrorResponse;
};
13 changes: 13 additions & 0 deletions src/providers/stability-ai/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { ProviderConfigs } from "../types";
import StabilityAIAPIConfig from "./api";
import { StabilityAIImageGenerateConfig, StabilityAIImageGenerateResponseTransform } from "./imageGenerate";

// Provider registration for Stability AI: wires the REST API settings and
// the image generation request/response mappers into the gateway. Only the
// imageGenerate route is supported for this provider.
const StabilityAIConfig: ProviderConfigs = {
  api: StabilityAIAPIConfig,
  imageGenerate: StabilityAIImageGenerateConfig,
  responseTransforms: {
    imageGenerate: StabilityAIImageGenerateResponseTransform
  }
};

export default StabilityAIConfig;
13 changes: 12 additions & 1 deletion src/providers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ export interface ProviderAPIConfig {
getEndpoint?: Function;
/** The endpoint for the 'stream-chatComplete' function. */
proxy?: string;
/** The endpoint for 'imageGenerate' function */
imageGenerate?: string
}

export type endpointStrings = 'complete' | 'chatComplete' | 'embed' | 'rerank' | 'moderate' | 'stream-complete' | 'stream-chatComplete' | 'proxy'
export type endpointStrings = 'complete' | 'chatComplete' | 'embed' | 'rerank' | 'moderate' | 'stream-complete' | 'stream-chatComplete' | 'proxy' | 'imageGenerate'

/**
* A collection of API configurations for multiple AI providers.
Expand Down Expand Up @@ -143,3 +145,12 @@ export interface ErrorResponse {
provider: string
}

/**
 * The structure of an image generation response (the gateway's unified,
 * OpenAI-style shape that every provider transform returns).
 * @interface
 */
export interface ImageGenerateResponse {
  created: string, // creation timestamp; NOTE(review): providers appear to supply epoch values as strings — confirm units with callers
  data: object[], // one entry per generated image (e.g. { url } or { b64_json })
  provider: string; // slug of the provider that served the request
}

0 comments on commit 4d3425f

Please sign in to comment.