Skip to content

Commit

Permalink
Add model list and search to AI binding (#3318)
Browse files Browse the repository at this point in the history
  • Loading branch information
G4brym authored Jan 14, 2025
1 parent 1d42e0d commit 7053d6c
Show file tree
Hide file tree
Showing 25 changed files with 744 additions and 0 deletions.
58 changes: 58 additions & 0 deletions src/cloudflare/internal/ai-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,47 @@ export type AiOptions = {
sessionOptions?: SessionOptions;
};

export type AiModelsSearchParams = {
author?: string;
hide_experimental?: boolean;
page?: number;
per_page?: number;
search?: string;
source?: number;
task?: string;
};

export type AiModelsSearchObject = {
id: string;
source: number;
name: string;
description: string;
task: {
id: string;
name: string;
description: string;
};
tags: string[];
properties: {
property_id: string;
value: string;
}[];
};

export class InferenceUpstreamError extends Error {
public constructor(message: string, name = 'InferenceUpstreamError') {
super(message);
this.name = name;
}
}

export class AiInternalError extends Error {
public constructor(message: string, name = 'AiInternalError') {
super(message);
this.name = name;
}
}

export class Ai {
private readonly fetcher: Fetcher;

Expand Down Expand Up @@ -163,6 +197,30 @@ export class Ai {
}
}

public async models(
params: AiModelsSearchParams = {}
): Promise<AiModelsSearchObject[]> {
const url = new URL('https://workers-binding.ai/ai-api/models/search');

for (const [key, value] of Object.entries(params)) {
url.searchParams.set(key, value.toString());
}

const res = await this.fetcher.fetch(url, { method: 'GET' });

switch (res.status) {
case 200: {
const data = (await res.json()) as { result: AiModelsSearchObject[] };
return data.result;
}
default: {
const data = (await res.json()) as { errors: { message: string }[] };

throw new AiInternalError(data.errors[0]?.message || 'Internal Error');
}
}
}

public gateway(gatewayId: string): AiGateway {
return new AiGateway(this.fetcher, gatewayId);
}
Expand Down
62 changes: 62 additions & 0 deletions src/cloudflare/internal/test/ai/ai-api-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,67 @@ export const tests = {
requestUrl: 'https://workers-binding.ai/ai-gateway/run?version=3',
});
}

{
// Test models
const resp = await env.ai.models();

assert.deepStrictEqual(resp, [
{
id: 'f8703a00-ed54-4f98-bdc3-cd9a813286f3',
source: 1,
name: '@cf/qwen/qwen1.5-0.5b-chat',
description:
'Qwen1.5 is the improved version of Qwen, the large language model series developed by Alibaba Cloud.',
task: {
id: 'c329a1f9-323d-4e91-b2aa-582dd4188d34',
name: 'Text Generation',
description:
'Family of generative text models, such as large language models (LLM), that can be adapted for a variety of natural language tasks.',
},
tags: [],
properties: [
{
property_id: 'debug',
value: 'https://workers-binding.ai/ai-api/models/search',
},
],
},
]);
}

{
// Test models with params
const resp = await env.ai.models({
search: 'test',
per_page: 3,
page: 1,
task: 'asd',
});

assert.deepStrictEqual(resp, [
{
id: 'f8703a00-ed54-4f98-bdc3-cd9a813286f3',
source: 1,
name: '@cf/qwen/qwen1.5-0.5b-chat',
description:
'Qwen1.5 is the improved version of Qwen, the large language model series developed by Alibaba Cloud.',
task: {
id: 'c329a1f9-323d-4e91-b2aa-582dd4188d34',
name: 'Text Generation',
description:
'Family of generative text models, such as large language models (LLM), that can be adapted for a variety of natural language tasks.',
},
tags: [],
properties: [
{
property_id: 'debug',
value:
'https://workers-binding.ai/ai-api/models/search?search=test&per_page=3&page=1&task=asd',
},
],
},
]);
}
},
};
30 changes: 30 additions & 0 deletions src/cloudflare/internal/test/ai/ai-mock.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,36 @@

export default {
async fetch(request, env, ctx) {
const url = new URL(request.url);

if (url.pathname === '/ai-api/models/search') {
return Response.json({
success: true,
result: [
{
id: 'f8703a00-ed54-4f98-bdc3-cd9a813286f3',
source: 1,
name: '@cf/qwen/qwen1.5-0.5b-chat',
description:
'Qwen1.5 is the improved version of Qwen, the large language model series developed by Alibaba Cloud.',
task: {
id: 'c329a1f9-323d-4e91-b2aa-582dd4188d34',
name: 'Text Generation',
description:
'Family of generative text models, such as large language models (LLM), that can be adapted for a variety of natural language tasks.',
},
tags: [],
properties: [
{
property_id: 'debug',
value: request.url,
},
],
},
],
});
}

const data = await request.json();

const modelName = request.headers.get('cf-consn-model-id');
Expand Down
6 changes: 6 additions & 0 deletions types/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,9 @@ $ bazel test //types:all
- `src/{print,program}.ts`: helpers for printing nodes and creating programs
- `defines`: additional TypeScript-only definitions that don't correspond to
`workerd` runtime APIs, appended to the end of outputs

## Updates types

```shell
bazel build //types:types && rm -rf types/generated-snapshot && cp -r bazel-bin/types/definitions types/generated-snapshot
```
28 changes: 28 additions & 0 deletions types/defines/ai.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,33 @@ export interface AiModels {
"@cf/llava-hf/llava-1.5-7b-hf": BaseAiImageToText;
}
export type ModelListType = Record<string, any>;
export type AiModelsSearchParams = {
author?: string,
hide_experimental?: boolean
page?: number
per_page?: number
search?: string
source?: number
task?: string
}
export type AiModelsSearchObject = {
id: string,
source: number,
name: string,
description: string,
task: {
id: string,
name: string,
description: string,
},
tags: string[],
properties: {
property_id: string,
value: string,
}[],
}
export interface InferenceUpstreamError extends Error {}
export interface AiInternalError extends Error {}
export declare abstract class Ai<ModelList extends ModelListType = AiModels> {
aiGatewayLogId: string | null;
gateway(gatewayId: string): AiGateway;
Expand All @@ -288,4 +315,5 @@ export declare abstract class Ai<ModelList extends ModelListType = AiModels> {
inputs: ModelList[Name]["inputs"],
options?: AiOptions
): Promise<ModelList[Name]["postProcessedOutputs"]>;
public models(params?: AiModelsSearchParams): Promise<AiModelsSearchObject[]>;
}
28 changes: 28 additions & 0 deletions types/generated-snapshot/2021-11-03/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3683,6 +3683,33 @@ interface AiModels {
"@cf/llava-hf/llava-1.5-7b-hf": BaseAiImageToText;
}
type ModelListType = Record<string, any>;
type AiModelsSearchParams = {
author?: string;
hide_experimental?: boolean;
page?: number;
per_page?: number;
search?: string;
source?: number;
task?: string;
};
type AiModelsSearchObject = {
id: string;
source: number;
name: string;
description: string;
task: {
id: string;
name: string;
description: string;
};
tags: string[];
properties: {
property_id: string;
value: string;
}[];
};
interface InferenceUpstreamError extends Error {}
interface AiInternalError extends Error {}
declare abstract class Ai<ModelList extends ModelListType = AiModels> {
aiGatewayLogId: string | null;
gateway(gatewayId: string): AiGateway;
Expand All @@ -3691,6 +3718,7 @@ declare abstract class Ai<ModelList extends ModelListType = AiModels> {
inputs: ModelList[Name]["inputs"],
options?: AiOptions,
): Promise<ModelList[Name]["postProcessedOutputs"]>;
public models(params?: AiModelsSearchParams): Promise<AiModelsSearchObject[]>;
}
type GatewayOptions = {
id: string;
Expand Down
28 changes: 28 additions & 0 deletions types/generated-snapshot/2021-11-03/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3695,6 +3695,33 @@ export interface AiModels {
"@cf/llava-hf/llava-1.5-7b-hf": BaseAiImageToText;
}
export type ModelListType = Record<string, any>;
export type AiModelsSearchParams = {
author?: string;
hide_experimental?: boolean;
page?: number;
per_page?: number;
search?: string;
source?: number;
task?: string;
};
export type AiModelsSearchObject = {
id: string;
source: number;
name: string;
description: string;
task: {
id: string;
name: string;
description: string;
};
tags: string[];
properties: {
property_id: string;
value: string;
}[];
};
export interface InferenceUpstreamError extends Error {}
export interface AiInternalError extends Error {}
export declare abstract class Ai<ModelList extends ModelListType = AiModels> {
aiGatewayLogId: string | null;
gateway(gatewayId: string): AiGateway;
Expand All @@ -3703,6 +3730,7 @@ export declare abstract class Ai<ModelList extends ModelListType = AiModels> {
inputs: ModelList[Name]["inputs"],
options?: AiOptions,
): Promise<ModelList[Name]["postProcessedOutputs"]>;
public models(params?: AiModelsSearchParams): Promise<AiModelsSearchObject[]>;
}
export type GatewayOptions = {
id: string;
Expand Down
28 changes: 28 additions & 0 deletions types/generated-snapshot/2022-01-31/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3709,6 +3709,33 @@ interface AiModels {
"@cf/llava-hf/llava-1.5-7b-hf": BaseAiImageToText;
}
type ModelListType = Record<string, any>;
type AiModelsSearchParams = {
author?: string;
hide_experimental?: boolean;
page?: number;
per_page?: number;
search?: string;
source?: number;
task?: string;
};
type AiModelsSearchObject = {
id: string;
source: number;
name: string;
description: string;
task: {
id: string;
name: string;
description: string;
};
tags: string[];
properties: {
property_id: string;
value: string;
}[];
};
interface InferenceUpstreamError extends Error {}
interface AiInternalError extends Error {}
declare abstract class Ai<ModelList extends ModelListType = AiModels> {
aiGatewayLogId: string | null;
gateway(gatewayId: string): AiGateway;
Expand All @@ -3717,6 +3744,7 @@ declare abstract class Ai<ModelList extends ModelListType = AiModels> {
inputs: ModelList[Name]["inputs"],
options?: AiOptions,
): Promise<ModelList[Name]["postProcessedOutputs"]>;
public models(params?: AiModelsSearchParams): Promise<AiModelsSearchObject[]>;
}
type GatewayOptions = {
id: string;
Expand Down
Loading

0 comments on commit 7053d6c

Please # to comment.