-
Notifications
You must be signed in to change notification settings - Fork 532
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #204 from Portkey-AI/feat/image-models
Feat/image models
- Loading branch information
Showing
14 changed files
with
250 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import { constructConfigFromRequestHeaders, tryTargetsRecursively } from "./handlerUtils"; | ||
import { Context } from "hono"; | ||
|
||
/** | ||
* Handles the '/images/generations' API request by selecting the appropriate provider(s) and making the request to them. | ||
* | ||
* @param {Context} c - The Cloudflare Worker context. | ||
* @returns {Promise<Response>} - The response from the provider. | ||
* @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails. | ||
* @throws Will throw an 500 error if the handler fails due to some reasons | ||
*/ | ||
export async function imageGenerationsHandler(c: Context): Promise<Response> { | ||
try { | ||
let request = await c.req.json(); | ||
let requestHeaders = Object.fromEntries(c.req.raw.headers); | ||
const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders) | ||
|
||
const tryTargetsResponse = await tryTargetsRecursively( | ||
c, | ||
camelCaseConfig, | ||
request, | ||
requestHeaders, | ||
"imageGenerate", | ||
"POST", | ||
"config" | ||
); | ||
|
||
return tryTargetsResponse; | ||
} catch (err: any) { | ||
console.log("imageGenerate error", err.message); | ||
return new Response( | ||
JSON.stringify({ | ||
status: "failure", | ||
message: "Something went wrong", | ||
}), { | ||
status: 500, | ||
headers: { | ||
"content-type": "application/json" | ||
} | ||
} | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import { ImageGenerateResponse, ProviderConfig } from "../types"; | ||
|
||
export const OpenAIImageGenerateConfig: ProviderConfig = { | ||
prompt: { | ||
param: "prompt", | ||
required: true | ||
}, | ||
model: { | ||
param: "model", | ||
required: true, | ||
default: "dall-e-2" | ||
}, | ||
n: { | ||
param: "n", | ||
min: 1, | ||
max: 10 | ||
}, | ||
quality: { | ||
param: "quality" | ||
}, | ||
response_format: { | ||
param: "response_format" | ||
}, | ||
size: { | ||
param: "size" | ||
}, | ||
style: { | ||
param: "style" | ||
}, | ||
user: { | ||
param: "user" | ||
} | ||
} | ||
|
||
interface OpenAIImageObject { | ||
b64_json?: string; // The base64-encoded JSON of the generated image, if response_format is b64_json. | ||
url?: string; // The URL of the generated image, if response_format is url (default). | ||
revised_prompt?: string; // The prompt that was used to generate the image, if there was any revision to the prompt. | ||
} | ||
|
||
interface OpenAIImageGenerateResponse extends ImageGenerateResponse { | ||
data: OpenAIImageObject[] | ||
} | ||
|
||
export const OpenAIImageGenerateResponseTransform: (response: OpenAIImageGenerateResponse) => ImageGenerateResponse = (response) => response; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import { ProviderAPIConfig } from "../types"; | ||
|
||
const StabilityAIAPIConfig: ProviderAPIConfig = { | ||
baseURL: "https://api.stability.ai/v1", | ||
headers: (API_KEY:string) => { | ||
return {"Authorization": `Bearer ${API_KEY}`} | ||
}, | ||
getEndpoint: (fn:string, ENGINE_ID:string, url?: string) => { | ||
let mappedFn = fn; | ||
if (fn === "proxy" && url && url?.indexOf("text-to-image") > -1) { | ||
mappedFn = "imageGenerate" | ||
} | ||
|
||
switch(mappedFn) { | ||
case 'imageGenerate': { | ||
return `/generation/${ENGINE_ID}/text-to-image` | ||
} | ||
} | ||
} | ||
}; | ||
|
||
export default StabilityAIAPIConfig; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import { STABILITY_AI } from "../../globals"; | ||
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from "../types"; | ||
|
||
export const StabilityAIImageGenerateConfig: ProviderConfig = { | ||
prompt: { | ||
param: "text_prompts", | ||
required: true, | ||
transform: (params: any) => { | ||
return [{ | ||
text: params.prompt, | ||
weight: 1 | ||
}] | ||
} | ||
}, | ||
n: { | ||
param: "samples", | ||
min: 1, | ||
max: 10 | ||
}, | ||
size: [{ | ||
param: "height", | ||
transform: (params:any) => parseInt(params.size.toLowerCase().split('x')[1]), | ||
min: 320 | ||
}, { | ||
param: "width", | ||
transform: (params:any) => parseInt(params.size.toLowerCase().split('x')[0]), | ||
min: 320 | ||
}], | ||
style: { | ||
param: "style_preset" | ||
} | ||
} | ||
|
||
interface StabilityAIImageGenerateResponse extends ImageGenerateResponse { | ||
artifacts: ImageArtifact[]; | ||
} | ||
|
||
interface StabilityAIImageGenerateResponse extends ImageGenerateResponse { | ||
artifacts: ImageArtifact[]; | ||
} | ||
|
||
interface StabilityAIImageGenerateErrorResponse { | ||
id: string; | ||
name: string; | ||
message: string; | ||
} | ||
|
||
interface ImageArtifact { | ||
base64: string; // Image encoded in base64 | ||
finishReason: 'CONTENT_FILTERED' | 'ERROR' | 'SUCCESS'; // Enum for finish reason | ||
seed: number; // The seed associated with this image | ||
} | ||
|
||
|
||
export const StabilityAIImageGenerateResponseTransform: (response: StabilityAIImageGenerateResponse | StabilityAIImageGenerateErrorResponse, responseStatus: number) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => { | ||
if (responseStatus !== 200 && 'message' in response) { | ||
return { | ||
error: { | ||
message: response.message, | ||
type: response.name, | ||
param: null, | ||
code: null, | ||
}, | ||
provider: STABILITY_AI | ||
} | ||
} | ||
|
||
if ('artifacts' in response) { | ||
return { | ||
created: `${new Date().getTime()}`, // Corrected method call | ||
data: response.artifacts.map(art => ({b64_json: art.base64})), // Corrected object creation within map | ||
provider: STABILITY_AI | ||
}; | ||
} | ||
|
||
return { | ||
error: { | ||
message: `Invalid response recieved from ${STABILITY_AI}: ${JSON.stringify(response)}`, | ||
type: null, | ||
param: null, | ||
code: null | ||
}, | ||
provider: STABILITY_AI | ||
} as ErrorResponse; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import { ProviderConfigs } from "../types"; | ||
import StabilityAIAPIConfig from "./api"; | ||
import { StabilityAIImageGenerateConfig, StabilityAIImageGenerateResponseTransform } from "./imageGenerate"; | ||
|
||
const StabilityAIConfig: ProviderConfigs = { | ||
api: StabilityAIAPIConfig, | ||
imageGenerate: StabilityAIImageGenerateConfig, | ||
responseTransforms: { | ||
imageGenerate: StabilityAIImageGenerateResponseTransform | ||
} | ||
}; | ||
|
||
export default StabilityAIConfig; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters