From 93ae955dda44d041ec147dae1c13fae1c92b8efc Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 08:14:37 +0000 Subject: [PATCH 01/37] Support tasks with remote callbacks --- apps/webapp/app/models/task.server.ts | 1 + ....runs.$runId.tasks.$id.callback.$secret.ts | 138 ++++++++++++++++++ .../app/routes/api.v1.runs.$runId.tasks.ts | 26 +++- .../runs/performRunExecutionV2.server.ts | 4 +- .../services/tasks/processCallbackTimeout.ts | 75 ++++++++++ apps/webapp/app/services/worker.server.ts | 14 ++ packages/core/src/schemas/api.ts | 16 ++ packages/core/src/schemas/tasks.ts | 1 + .../migration.sql | 2 + packages/database/prisma/schema.prisma | 1 + packages/trigger-sdk/src/io.ts | 49 ++++--- 11 files changed, 302 insertions(+), 25 deletions(-) create mode 100644 apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts create mode 100644 apps/webapp/app/services/tasks/processCallbackTimeout.ts create mode 100644 packages/database/prisma/migrations/20230925174509_add_callback_url/migration.sql diff --git a/apps/webapp/app/models/task.server.ts b/apps/webapp/app/models/task.server.ts index 19951eb6f7..e8a6464901 100644 --- a/apps/webapp/app/models/task.server.ts +++ b/apps/webapp/app/models/task.server.ts @@ -23,5 +23,6 @@ export function taskWithAttemptsToServerTask(task: TaskWithAttempts): ServerTask attempts: task.attempts.length, idempotencyKey: task.idempotencyKey, operation: task.operation, + callbackUrl: task.callbackUrl, }; } diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts new file mode 100644 index 0000000000..3925d830dd --- /dev/null +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts @@ -0,0 +1,138 @@ +import type { ActionArgs } from "@remix-run/server-runtime"; +import { json } from "@remix-run/server-runtime"; +import type { CallbackTaskBodyOutput } from "@trigger.dev/core"; +import { CallbackTaskBodyInputSchema } from "@trigger.dev/core"; +import { RuntimeEnvironmentType } from "@trigger.dev/database"; +import { z } from "zod"; +import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server"; +import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server"; +import { logger } from "~/services/logger.server"; + +const ParamsSchema = z.object({ + runId: z.string(), + id: z.string(), + secret: z.string(), +}); + +export async function action({ request, params }: ActionArgs) { + // Ensure this is a POST request + if (request.method.toUpperCase() !== "POST") { + return { status: 405, body: "Method Not Allowed" }; + } + + const { runId, id } = ParamsSchema.parse(params); + + // Now parse the request body + const anyBody = await request.json(); + + // Allows any valid object + // TODO: maybe add proper schema parsing during io.runTask(), or even skip this step + const body = CallbackTaskBodyInputSchema.safeParse(anyBody); + + if (!body.success) { + return json({ error: "Invalid request body" }, { status: 400 }); + } + + const service = new CallbackRunTaskService(); + + try { + await service.call(runId, id, body.data, new URL(request.url).href); + + return json({ success: true }); + } catch (error) { + if (error instanceof Error) { + logger.error("Error while processing task callback:", { error }); + } + + return json({ error: "Something went wrong" }, { status: 500 }); + } +} + +export class CallbackRunTaskService { + #prismaClient: PrismaClient; + + constructor(prismaClient: PrismaClient = prisma) { + this.#prismaClient = prismaClient; + } + + public async call( + runId: string, + id: string, + taskBody: CallbackTaskBodyOutput, + callbackUrl: string + ): Promise { + const task = await findTask(prisma, id); + + if (!task) { + return; + } + + if (task.runId !== runId) { + return; + } + + if (task.status !== "WAITING") { + return; + } + + if (!task.callbackUrl) { + return; + } + + if (new URL(task.callbackUrl).pathname !== new URL(callbackUrl).pathname) { + logger.error("Callback URLs don't match", { runId, taskId: id, callbackUrl }); + return; + } + + logger.debug("CallbackRunTaskService.call()", { task }); + + await this.#resumeTask(task, taskBody); + } + + async #resumeTask(task: NonNullable, output: Record) { + await $transaction(this.#prismaClient, async (tx) => { + await tx.taskAttempt.updateMany({ + where: { + taskId: task.id, + status: "PENDING", + }, + data: { + status: "COMPLETED", + }, + }); + + await tx.task.update({ + where: { id: task.id }, + data: { + status: "COMPLETED", + completedAt: new Date(), + output: output ? output : undefined, + }, + }); + + await this.#resumeRunExecution(task, tx); + }); + } + + async #resumeRunExecution(task: NonNullable, prisma: PrismaClientOrTransaction) { + await enqueueRunExecutionV2(task.run, prisma, { + skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT, + }); + } +} + +type FoundTask = Awaited>; + +async function findTask(prisma: PrismaClientOrTransaction, id: string) { + return prisma.task.findUnique({ + where: { id }, + include: { + run: { + include: { + environment: true, + queue: true, + }, + }, + }, + }); +} diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts index 0e5eaecd99..1dfcf164e7 100644 --- a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts @@ -9,6 +9,8 @@ import { authenticateApiRequest } from "~/services/apiAuth.server"; import { logger } from "~/services/logger.server"; import { ulid } from "~/services/ulid.server"; import { workerQueue } from "~/services/worker.server"; +import { generateSecret } from "~/services/sources/utils.server"; +import { env } from "~/env.server"; const ParamsSchema = z.object({ runId: z.string(), @@ -106,10 +108,13 @@ export class RunTaskService { }, }); + const delayUntilInFuture = taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now(); + const callbackEnabled = taskBody.callback?.enabled; + if (existingTask) { if (existingTask.status === "CANCELED") { const existingTaskStatus = - (taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger + delayUntilInFuture || callbackEnabled || taskBody.trigger ? "WAITING" : taskBody.noop ? "COMPLETED" @@ -154,16 +159,21 @@ export class RunTaskService { status = "CANCELED"; } else { status = - (taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger + delayUntilInFuture || callbackEnabled || taskBody.trigger ? "WAITING" : taskBody.noop ? "COMPLETED" : "RUNNING"; } + const taskId = ulid(); + const callbackUrl = callbackEnabled + ? `${env.APP_ORIGIN}/api/v1/runs/${runId}/tasks/${taskId}/callback/${generateSecret()}` + : undefined; + const task = await tx.task.create({ data: { - id: ulid(), + id: taskId, idempotencyKey, displayKey: taskBody.displayKey, runConnection: taskBody.connectionKey @@ -194,6 +204,7 @@ export class RunTaskService { properties: taskBody.properties ?? undefined, redact: taskBody.redact ?? undefined, operation: taskBody.operation, + callbackUrl, style: taskBody.style ?? { style: "normal" }, attempts: { create: { @@ -217,6 +228,15 @@ export class RunTaskService { }, { tx, runAt: task.delayUntil ?? undefined } ); + } else if (task.status === "WAITING" && callbackUrl && taskBody.callback) { + // We need to schedule the callback timeout + await workerQueue.enqueue( + "processCallbackTimeout", + { + id: task.id, + }, + { tx, runAt: new Date(Date.now() + taskBody.callback.timeoutInSeconds * 1000) } + ); } return task; diff --git a/apps/webapp/app/services/runs/performRunExecutionV2.server.ts b/apps/webapp/app/services/runs/performRunExecutionV2.server.ts index 61feab0185..139a1bb05f 100644 --- a/apps/webapp/app/services/runs/performRunExecutionV2.server.ts +++ b/apps/webapp/app/services/runs/performRunExecutionV2.server.ts @@ -429,7 +429,9 @@ export class PerformRunExecutionV2Service { // If the task has an operation, then the next performRunExecution will occur // when that operation has finished - if (!data.task.operation) { + // Tasks with callbacks enabled will also get processed separately, i.e. when + // they time out, or on valid requests to their callbackUrl + if (!data.task.operation && !data.task.callbackUrl) { await enqueueRunExecutionV2(run, tx, { runAt: data.task.delayUntil ?? undefined, resumeTaskId: data.task.id, diff --git a/apps/webapp/app/services/tasks/processCallbackTimeout.ts b/apps/webapp/app/services/tasks/processCallbackTimeout.ts new file mode 100644 index 0000000000..e4528935b9 --- /dev/null +++ b/apps/webapp/app/services/tasks/processCallbackTimeout.ts @@ -0,0 +1,75 @@ +import { RuntimeEnvironmentType, type Task } from "@trigger.dev/database"; +import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server"; +import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server"; +import { logger } from "../logger.server"; + +type FoundTask = Awaited>; + +export class ProcessCallbackTimeoutService { + #prismaClient: PrismaClient; + + constructor(prismaClient: PrismaClient = prisma) { + this.#prismaClient = prismaClient; + } + + public async call(id: string) { + const task = await findTask(this.#prismaClient, id); + + if (!task) { + return; + } + + if (task.status !== "WAITING" || !task.callbackUrl) { + return; + } + + logger.debug("ProcessCallbackTimeoutService.call", { task }); + + return await this.#resumeTask(task, null); + } + + async #resumeTask(task: NonNullable, output: any) { + await $transaction(this.#prismaClient, async (tx) => { + await tx.taskAttempt.updateMany({ + where: { + taskId: task.id, + status: "PENDING", + }, + data: { + status: "COMPLETED", + }, + }); + + await tx.task.update({ + where: { id: task.id }, + data: { + status: "COMPLETED", + completedAt: new Date(), + output: output ? output : undefined, + }, + }); + + await this.#resumeRunExecution(task, tx); + }); + } + + async #resumeRunExecution(task: NonNullable, prisma: PrismaClientOrTransaction) { + await enqueueRunExecutionV2(task.run, prisma, { + skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT, + }); + } +} + +async function findTask(prisma: PrismaClient, id: string) { + return prisma.task.findUnique({ + where: { id }, + include: { + run: { + include: { + environment: true, + queue: true, + }, + }, + }, + }); +} diff --git a/apps/webapp/app/services/worker.server.ts b/apps/webapp/app/services/worker.server.ts index 93fc0753af..9cd8a73f36 100644 --- a/apps/webapp/app/services/worker.server.ts +++ b/apps/webapp/app/services/worker.server.ts @@ -19,6 +19,7 @@ import { DeliverScheduledEventService } from "./schedules/deliverScheduledEvent. import { ActivateSourceService } from "./sources/activateSource.server"; import { DeliverHttpSourceRequestService } from "./sources/deliverHttpSourceRequest.server"; import { PerformTaskOperationService } from "./tasks/performTaskOperation.server"; +import { ProcessCallbackTimeoutService } from "./tasks/processCallbackTimeout"; import { addMissingVersionField } from "@trigger.dev/core"; const workerCatalog = { @@ -30,6 +31,9 @@ const workerCatalog = { }), scheduleEmail: DeliverEmailSchema, startRun: z.object({ id: z.string() }), + processCallbackTimeout: z.object({ + id: z.string(), + }), performTaskOperation: z.object({ id: z.string(), }), @@ -239,6 +243,16 @@ function getWorkerQueue() { await service.call(payload.id); }, }, + processCallbackTimeout: { + priority: 0, // smaller number = higher priority + queueName: (payload) => `tasks:${payload.id}`, + maxAttempts: 3, + handler: async (payload, job) => { + const service = new ProcessCallbackTimeoutService(); + + await service.call(payload.id); + }, + }, performTaskOperation: { priority: 0, // smaller number = higher priority queueName: (payload) => `tasks:${payload.id}`, diff --git a/packages/core/src/schemas/api.ts b/packages/core/src/schemas/api.ts index b5f1f91610..127f95eace 100644 --- a/packages/core/src/schemas/api.ts +++ b/packages/core/src/schemas/api.ts @@ -608,6 +608,13 @@ export const RunTaskOptionsSchema = z.object({ params: z.any(), /** The style of the log entry. */ style: StyleSchema.optional(), + /** Allows you to return the data sent to task.callbackUrl instead of the async callback return */ + callback: z.object({ + /** Enable the callback feature */ + enabled: z.boolean(), + /** Time to wait for callback requests */ + timeoutInSeconds: z.number(), + }).partial().optional(), /** Allows you to link the Integration connection in the logs. This is handled automatically in integrations. */ connectionKey: z.string().optional(), /** An operation you want to perform on the Trigger.dev platform, current only "fetch" is supported. If you wish to `fetch` use [`io.backgroundFetch()`](https://trigger.dev/docs/sdk/io/backgroundfetch) instead. */ @@ -634,10 +641,19 @@ export type RunTaskBodyInput = z.infer; export const RunTaskBodyOutputSchema = RunTaskBodyInputSchema.extend({ params: DeserializedJsonSchema.optional().nullable(), + callback: z.object({ + enabled: z.boolean(), + timeoutInSeconds: z.number().default(3600), + }).optional(), }); export type RunTaskBodyOutput = z.infer; +export const CallbackTaskBodyInputSchema = z.object({}).passthrough(); + +export type CallbackTaskBodyInput = Prettify>; +export type CallbackTaskBodyOutput = z.infer; + export const CompleteTaskBodyInputSchema = RunTaskBodyInputSchema.pick({ properties: true, description: true, diff --git a/packages/core/src/schemas/tasks.ts b/packages/core/src/schemas/tasks.ts index fe6d43bbdb..559dc4bab3 100644 --- a/packages/core/src/schemas/tasks.ts +++ b/packages/core/src/schemas/tasks.ts @@ -31,6 +31,7 @@ export const TaskSchema = z.object({ parentId: z.string().optional().nullable(), style: StyleSchema.optional().nullable(), operation: z.string().optional().nullable(), + callbackUrl: z.string().optional().nullable(), }); export const ServerTaskSchema = TaskSchema.extend({ diff --git a/packages/database/prisma/migrations/20230925174509_add_callback_url/migration.sql b/packages/database/prisma/migrations/20230925174509_add_callback_url/migration.sql new file mode 100644 index 0000000000..4808101efc --- /dev/null +++ b/packages/database/prisma/migrations/20230925174509_add_callback_url/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "Task" ADD COLUMN "callbackUrl" TEXT; diff --git a/packages/database/prisma/schema.prisma b/packages/database/prisma/schema.prisma index 86757f86d3..344d23786e 100644 --- a/packages/database/prisma/schema.prisma +++ b/packages/database/prisma/schema.prisma @@ -795,6 +795,7 @@ model Task { redact Json? style Json? operation String? + callbackUrl String? startedAt DateTime? completedAt DateTime? diff --git a/packages/trigger-sdk/src/io.ts b/packages/trigger-sdk/src/io.ts index 9286e71090..58c8ff4338 100644 --- a/packages/trigger-sdk/src/io.ts +++ b/packages/trigger-sdk/src/io.ts @@ -526,12 +526,12 @@ export class IO { * @param onError The callback that will be called when the Task fails. The callback receives the error, the Task and the IO as parameters. If you wish to retry then return an object with a `retryAt` property. * @returns A Promise that resolves with the returned value of the callback. */ - async runTask | void>( + async runTask | void, TOptions extends RunTaskOptions = RunTaskOptions>( key: string | any[], callback: (task: ServerTask, io: IO) => Promise, - options?: RunTaskOptions, + options?: TOptions, onError?: RunTaskErrorCallback - ): Promise { + ): Promise { const parentId = this._taskStorage.getStore()?.taskId; if (parentId) { @@ -592,24 +592,6 @@ export class IO { throw new Error(task.error ?? task?.output ? JSON.stringify(task.output) : "Task errored"); } - if (task.status === "WAITING") { - this._logger.debug("Task waiting", { - idempotencyKey, - task, - }); - - throw new ResumeWithTaskError(task); - } - - if (task.status === "RUNNING" && typeof task.operation === "string") { - this._logger.debug("Task running operation", { - idempotencyKey, - task, - }); - - throw new ResumeWithTaskError(task); - } - const executeTask = async () => { try { const result = await callback(task, this); @@ -621,6 +603,9 @@ export class IO { task, }); + // TODO: empty return? maybe don't even parse first + if (task.status === "WAITING") return output; + const completedTask = await this._apiClient.completeTask(this._id, task.id, { output: output ?? undefined, properties: task.outputProperties ?? undefined, @@ -696,6 +681,28 @@ export class IO { } }; + if (task.status === "WAITING") { + this._logger.debug("Task waiting", { + idempotencyKey, + task, + }); + + if (task.callbackUrl) { + await this._taskStorage.run({ taskId: task.id }, executeTask); + } + + throw new ResumeWithTaskError(task); + } + + if (task.status === "RUNNING" && typeof task.operation === "string") { + this._logger.debug("Task running operation", { + idempotencyKey, + task, + }); + + throw new ResumeWithTaskError(task); + } + return this._taskStorage.run({ taskId: task.id }, executeTask); } From ff0863e4e41f64b4a0eb91ebc7d2c49d713b038b Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 08:15:22 +0000 Subject: [PATCH 02/37] Add common integration tsconfig --- config-packages/tsconfig/integration.json | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 config-packages/tsconfig/integration.json diff --git a/config-packages/tsconfig/integration.json b/config-packages/tsconfig/integration.json new file mode 100644 index 0000000000..753e6091dd --- /dev/null +++ b/config-packages/tsconfig/integration.json @@ -0,0 +1,17 @@ +{ + "extends": "./node18.json", + "compilerOptions": { + "lib": ["DOM", "DOM.Iterable", "ES2019"], + "paths": { + "@trigger.dev/sdk/*": ["../../packages/trigger-sdk/src/*"], + "@trigger.dev/sdk": ["../../packages/trigger-sdk/src/index"], + "@trigger.dev/integration-kit/*": ["../../packages/integration-kit/src/*"], + "@trigger.dev/integration-kit": ["../../packages/integration-kit/src/index"] + }, + "declaration": false, + "declarationMap": false, + "baseUrl": ".", + "stripInternal": true + }, + "exclude": ["node_modules"] +} From f998afc2a96b431d597f31b33b2c2d28e06824de Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 09:29:44 +0000 Subject: [PATCH 03/37] Add Replicate integration --- integrations/replicate/README.md | 1 + integrations/replicate/package.json | 37 +++ integrations/replicate/src/collections.ts | 35 +++ integrations/replicate/src/deployments.ts | 70 ++++++ integrations/replicate/src/index.ts | 271 ++++++++++++++++++++++ integrations/replicate/src/models.ts | 79 +++++++ integrations/replicate/src/predictions.ts | 92 ++++++++ integrations/replicate/src/trainings.ts | 104 +++++++++ integrations/replicate/src/types.ts | 1 + integrations/replicate/src/utils.ts | 47 ++++ integrations/replicate/tsconfig.json | 4 + integrations/replicate/tsup.config.ts | 22 ++ 12 files changed, 763 insertions(+) create mode 100644 integrations/replicate/README.md create mode 100644 integrations/replicate/package.json create mode 100644 integrations/replicate/src/collections.ts create mode 100644 integrations/replicate/src/deployments.ts create mode 100644 integrations/replicate/src/index.ts create mode 100644 integrations/replicate/src/models.ts create mode 100644 integrations/replicate/src/predictions.ts create mode 100644 integrations/replicate/src/trainings.ts create mode 100644 integrations/replicate/src/types.ts create mode 100644 integrations/replicate/src/utils.ts create mode 100644 integrations/replicate/tsconfig.json create mode 100644 integrations/replicate/tsup.config.ts diff --git a/integrations/replicate/README.md b/integrations/replicate/README.md new file mode 100644 index 0000000000..67a4b88e87 --- /dev/null +++ b/integrations/replicate/README.md @@ -0,0 +1 @@ +# @trigger.dev/replicate diff --git a/integrations/replicate/package.json b/integrations/replicate/package.json new file mode 100644 index 0000000000..ac7f8aecf3 --- /dev/null +++ b/integrations/replicate/package.json @@ -0,0 +1,37 @@ +{ + "name": "@trigger.dev/replicate", + "version": "2.1.4", + "description": "Trigger.dev integration for replicate", + "main": "./dist/index.js", + "types": "./dist/index.d.ts", + "publishConfig": { + "access": "public" + }, + "files": [ + "dist/index.js", + "dist/index.d.ts", + "dist/index.js.map" + ], + "devDependencies": { + "@trigger.dev/tsconfig": "workspace:*", + "@types/node": "16.x", + "rimraf": "^3.0.2", + "tsup": "7.1.x", + "typescript": "4.9.4" + }, + "scripts": { + "clean": "rimraf dist", + "build": "npm run clean && npm run build:tsup", + "build:tsup": "tsup", + "typecheck": "tsc --noEmit" + }, + "dependencies": { + "@trigger.dev/integration-kit": "workspace:^2.1.0", + "@trigger.dev/sdk": "workspace:^2.1.0", + "replicate": "^0.18.1", + "zod": "3.21.4" + }, + "engines": { + "node": ">=16.8.0" + } +} \ No newline at end of file diff --git a/integrations/replicate/src/collections.ts b/integrations/replicate/src/collections.ts new file mode 100644 index 0000000000..31a94ff557 --- /dev/null +++ b/integrations/replicate/src/collections.ts @@ -0,0 +1,35 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import { Page, Collection } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { ReplicateReturnType } from "./types"; + +export class Collections { + constructor(private runTask: ReplicateRunTask) {} + + get(key: IntegrationTaskKey, params: { slug: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.collections.get(params.slug); + }, + { + name: "Get Collection", + params, + properties: [{ label: "Collection Slug", text: params.slug }], + } + ); + } + + list(key: IntegrationTaskKey): ReplicateReturnType> { + return this.runTask( + key, + (client) => { + return client.collections.list(); + }, + { + name: "List Collections", + } + ); + } +} diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts new file mode 100644 index 0000000000..e6ebd68a58 --- /dev/null +++ b/integrations/replicate/src/deployments.ts @@ -0,0 +1,70 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import ReplicateClient, { Prediction } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { createDeploymentProperties } from "./utils"; +import { ReplicateReturnType } from "./types"; + +export class Deployments { + constructor(private runTask: ReplicateRunTask) {} + + get predictions() { + return new Predictions(this.runTask); + } +} + +class Predictions { + constructor(private runTask: ReplicateRunTask) {} + + create( + key: IntegrationTaskKey, + params: { + deployment_owner: string; + deployment_name: string; + } & Parameters[2] + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + const { deployment_owner, deployment_name, ...options } = params; + + return client.deployments.predictions.create(deployment_owner, deployment_name, options); + }, + { + name: "Create Prediction With Deployment", + params, + properties: createDeploymentProperties(params), + } + ); + } + + createAndWaitForCompletion( + key: IntegrationTaskKey, + params: { + deployment_owner: string; + deployment_name: string; + } & Omit< + Parameters[2], + "webhook" | "webhook_events_filter" + > + ): ReplicateReturnType { + return this.runTask( + key, + (client, task) => { + const { deployment_owner, deployment_name, ...options } = params; + + return client.deployments.predictions.create(deployment_owner, deployment_name, { + ...options, + webhook: task.callbackUrl ?? undefined, + webhook_events_filter: ["completed"], + }); + }, + { + name: "Create And Await Prediction With Deployment", + params, + properties: createDeploymentProperties(params), + callback: { enabled: true }, + } + ); + } +} diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts new file mode 100644 index 0000000000..595e4bdcbf --- /dev/null +++ b/integrations/replicate/src/index.ts @@ -0,0 +1,271 @@ +import { + TriggerIntegration, + RunTaskOptions, + IO, + IOTask, + IntegrationTaskKey, + RunTaskErrorCallback, + Json, + retry, + ConnectionAuth, +} from "@trigger.dev/sdk"; +import ReplicateClient, { Page } from "replicate"; + +import { Predictions } from "./predictions"; +import { Models } from "./models"; +import { Trainings } from "./trainings"; +import { Collections } from "./collections"; +import { ReplicateReturnType } from "./types"; + +export type ReplicateIntegrationOptions = { + id: string; + apiKey: string; +}; + +export type ReplicateRunTask = InstanceType["runTask"]; + +export class Replicate implements TriggerIntegration { + private _options: ReplicateIntegrationOptions; + private _client?: any; + private _io?: IO; + private _connectionKey?: string; + + constructor(private options: ReplicateIntegrationOptions) { + if (Object.keys(options).includes("apiKey") && !options.apiKey) { + throw `Can't create Replicate integration (${options.id}) as apiKey was undefined`; + } + + this._options = options; + } + + get authSource() { + return "LOCAL" as const; + } + + get id() { + return this.options.id; + } + + get metadata() { + return { id: "replicate", name: "Replicate" }; + } + + cloneForRun(io: IO, connectionKey: string, auth?: ConnectionAuth) { + const replicate = new Replicate(this._options); + replicate._io = io; + replicate._connectionKey = connectionKey; + replicate._client = this.createClient(auth); + return replicate; + } + + createClient(auth?: ConnectionAuth) { + return new ReplicateClient({ + auth: this._options.apiKey, + }); + } + + runTask | void>( + key: IntegrationTaskKey, + callback: (client: ReplicateClient, task: IOTask, io: IO) => Promise, + options?: RunTaskOptions, + errorCallback?: RunTaskErrorCallback + ): Promise { + if (!this._io) throw new Error("No IO"); + if (!this._connectionKey) throw new Error("No connection key"); + + return this._io.runTask( + key, + (task, io) => { + if (!this._client) throw new Error("No client"); + return callback(this._client, task, io); + }, + { + icon: "replicate", + retry: retry.standardBackoff, + ...(options ?? {}), + connectionKey: this._connectionKey, + }, + errorCallback ?? onError + ); + } + + get collections() { + return new Collections(this.runTask.bind(this)); + } + + get models() { + return new Models(this.runTask.bind(this)); + } + + get predictions() { + return new Predictions(this.runTask.bind(this)); + } + + get trainings() { + return new Trainings(this.runTask.bind(this)); + } + + async *paginate( + task: (key: string) => Promise>, + key: IntegrationTaskKey, + counter: number = 0 + ): AsyncGenerator { + const boundTask = task.bind(this as any); + + const page = await boundTask(`${key}-${counter}`); + yield page.results; + + if (page.next) { + const nextStep = counter++; + + const nextPage = () => { + return this.request>(`${key}-${nextStep}`, { + route: page.next!, + options: { method: "GET" }, + }); + }; + + yield* this.paginate(nextPage, key, nextStep); + } + } + + async getAll( + task: (key: string) => Promise>, + key: IntegrationTaskKey + ): ReplicateReturnType { + let results: T[] = []; + + for await (const results of this.paginate(task, key)) { + results.concat(results); + } + + return results; + } + + request( + key: IntegrationTaskKey, + params: { + route: string | URL; + options: Parameters[1]; + } + ): ReplicateReturnType { + return this.runTask( + key, + async (client) => { + const response = await client.request(params.route, params.options); + + return response.json(); + }, + { + name: "Send Request", + params, + properties: [ + { label: "Route", text: params.route.toString() }, + ...(params.options.method ? [{ label: "Method", text: params.options.method }] : []), + ], + callback: { enabled: true }, + } + ); + } + + run( + key: IntegrationTaskKey, + params: { + identifier: Parameters[0]; + } & Omit< + Parameters[1], + "webhook" | "webhook_events_filter" | "wait" | "signal" + > + ): ReplicateReturnType { + const { identifier, ...options } = params; + + // see: https://github.com/replicate/replicate-javascript/blob/4b0d9cb0e226fab3d3d31de5b32261485acf5626/index.js#L102 + + const namePattern = /[a-zA-Z0-9]+(?:(?:[._]|__|[-]*)[a-zA-Z0-9]+)*/; + const pattern = new RegExp( + `^(?${namePattern.source})/(?${namePattern.source}):(?[0-9a-fA-F]+)$` + ); + + const match = identifier.match(pattern); + + if (!match || !match.groups) { + throw new Error('Invalid version. It must be in the format "owner/name:version"'); + } + + const { version } = match.groups; + + return this.predictions.createAndWaitForCompletion(key, { ...options, version }); + } + + // TODO: wait(prediction) - needs polling +} + +class ApiError extends Error { + constructor( + message: string, + readonly request: Request, + readonly response: Response + ) { + super(message); + this.name = "ApiError"; + } +} + +function isReplicateApiError(error: unknown): error is ApiError { + if (typeof error !== "object" || error === null) { + return false; + } + + const apiError = error as ApiError; + + return ( + apiError.name === "ApiError" && + apiError.request instanceof Request && + apiError.response instanceof Response + ); +} + +function shouldRetry(method: string, status: number) { + return status === 429 || (method === "GET" && status >= 500); +} + +export function onError(error: unknown): ReturnType { + if (!isReplicateApiError(error)) { + return; + } + + if (!shouldRetry(error.request.method, error.response.status)) { + return { + skipRetrying: true, + }; + } + + // see: https://github.com/replicate/replicate-javascript/blob/4b0d9cb0e226fab3d3d31de5b32261485acf5626/lib/util.js#L43 + + const retryAfter = error.response.headers.get("retry-after"); + + if (retryAfter) { + const resetDate = new Date(retryAfter); + + if (!Number.isNaN(resetDate.getTime())) { + return { + retryAt: resetDate, + error, + }; + } + } + + const rateLimitRemaining = error.response.headers.get("ratelimit-remaining"); + const rateLimitReset = error.response.headers.get("ratelimit-reset"); + + if (rateLimitRemaining === "0" && rateLimitReset) { + const resetDate = new Date(Number(rateLimitReset) * 1000); + + if (!Number.isNaN(resetDate.getTime())) { + return { + retryAt: resetDate, + error, + }; + } + } +} diff --git a/integrations/replicate/src/models.ts b/integrations/replicate/src/models.ts new file mode 100644 index 0000000000..4436bf1c8d --- /dev/null +++ b/integrations/replicate/src/models.ts @@ -0,0 +1,79 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import { Model, ModelVersion } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { modelProperties } from "./utils"; +import { ReplicateReturnType } from "./types"; + +export class Models { + constructor(private runTask: ReplicateRunTask) {} + + get( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + } + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.models.get(params.model_owner, params.model_name); + }, + { + name: "Get Model", + params, + properties: modelProperties(params), + } + ); + } + + get versions() { + return new Versions(this.runTask); + } +} + +class Versions { + constructor(private runTask: ReplicateRunTask) {} + + get( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + version_id: string; + } + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.models.versions.get(params.model_owner, params.model_name, params.version_id); + }, + { + name: "Get Model Version", + params, + properties: modelProperties(params), + } + ); + } + + list( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + } + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.models.versions.list(params.model_owner, params.model_name); + }, + { + name: "List Models", + params, + properties: modelProperties(params), + } + ); + } +} diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts new file mode 100644 index 0000000000..59567d1759 --- /dev/null +++ b/integrations/replicate/src/predictions.ts @@ -0,0 +1,92 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import ReplicateClient, { Page, Prediction } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { ReplicateReturnType } from "./types"; +import { createPredictionProperties } from "./utils"; + +export class Predictions { + constructor(private runTask: ReplicateRunTask) {} + + cancel(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.predictions.cancel(params.id); + }, + { + name: "Cancel Prediction", + params, + properties: [{ label: "Prediction ID", text: params.id }], + } + ); + } + + create( + key: IntegrationTaskKey, + params: Parameters[0] + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.predictions.create(params); + }, + { + name: "Create Prediction", + params, + properties: createPredictionProperties(params), + } + ); + } + + createAndWaitForCompletion( + key: IntegrationTaskKey, + params: Omit< + Parameters[0], + "webhook" | "webhook_events_filter" + > + ): ReplicateReturnType { + return this.runTask( + key, + (client, task) => { + return client.predictions.create({ + ...params, + webhook: task.callbackUrl ?? undefined, + webhook_events_filter: ["completed"], + }); + }, + { + name: "Create And Await Prediction", + params, + properties: createPredictionProperties(params), + callback: { enabled: true }, + } + ); + } + + get(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.predictions.get(params.id); + }, + { + name: "Get Prediction", + params, + properties: [{ label: "Prediction ID", text: params.id }], + } + ); + } + + list(key: IntegrationTaskKey): ReplicateReturnType> { + return this.runTask( + key, + (client) => { + return client.predictions.list(); + }, + { + name: "List Predictions", + } + ); + } +} diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts new file mode 100644 index 0000000000..21cc48c017 --- /dev/null +++ b/integrations/replicate/src/trainings.ts @@ -0,0 +1,104 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import ReplicateClient, { Page, Training } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { ReplicateReturnType } from "./types"; +import { modelProperties } from "./utils"; + +export class Trainings { + constructor(private runTask: ReplicateRunTask) {} + + cancel(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.trainings.cancel(params.id); + }, + { + name: "Cancel Training", + params, + properties: [{ label: "Training ID", text: params.id }], + } + ); + } + + create( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + version_id: string; + } & Parameters[3] + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + const { model_owner, model_name, version_id, ...options } = params; + + return client.trainings.create(model_owner, model_name, version_id, options); + }, + { + name: "Create Training", + params, + properties: modelProperties(params), + } + ); + } + + createAndWaitForCompletion( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + version_id: string; + } & Omit< + Parameters[3], + "webhook" | "webhook_events_filter" + > + ): ReplicateReturnType { + return this.runTask( + key, + (client, task) => { + const { model_owner, model_name, version_id, ...options } = params; + + return client.trainings.create(model_owner, model_name, version_id, { + ...options, + webhook: task.callbackUrl ?? undefined, + webhook_events_filter: ["completed"], + }); + }, + { + name: "Create And Await Training", + params, + properties: modelProperties(params), + callback: { enabled: true }, + } + ); + } + + get(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.trainings.get(params.id); + }, + { + name: "Get Training", + params, + properties: [{ label: "Training ID", text: params.id }], + } + ); + } + + list(key: IntegrationTaskKey): ReplicateReturnType> { + return this.runTask( + key, + async (client) => { + return client.trainings.list(); + }, + { + name: "List Trainings", + } + ); + } +} diff --git a/integrations/replicate/src/types.ts b/integrations/replicate/src/types.ts new file mode 100644 index 0000000000..9863160718 --- /dev/null +++ b/integrations/replicate/src/types.ts @@ -0,0 +1 @@ +export type ReplicateReturnType = Promise; diff --git a/integrations/replicate/src/utils.ts b/integrations/replicate/src/utils.ts new file mode 100644 index 0000000000..0f63fef902 --- /dev/null +++ b/integrations/replicate/src/utils.ts @@ -0,0 +1,47 @@ +export const createPredictionProperties = ( + params: Partial<{ + version: string; + stream: boolean; + }> +) => { + return [ + ...(params.version ? [{ label: "Model Version", text: params.version }] : []), + ...streamingProperty(params), + ]; +}; + +export const createDeploymentProperties = ( + params: Partial<{ + deployment_owner: string; + deployment_name: string; + stream: boolean; + }> +) => { + return [ + ...(params.deployment_owner + ? [{ label: "Deployment Owner", text: params.deployment_owner }] + : []), + ...(params.deployment_name ? [{ label: "Deployment Name", text: params.deployment_name }] : []), + ...streamingProperty(params), + ]; +}; + +export const modelProperties = ( + params: Partial<{ + model_owner: string; + model_name: string; + version_id: string; + destination: string; + }> +) => { + return [ + ...(params.model_owner ? [{ label: "Model Owner", text: params.model_owner }] : []), + ...(params.model_name ? [{ label: "Model Name", text: params.model_name }] : []), + ...(params.version_id ? [{ label: "Model Version", text: params.version_id }] : []), + ...(params.destination ? [{ label: "Destination Model", text: params.destination }] : []), + ]; +}; + +export const streamingProperty = (params: { stream?: boolean }) => { + return [{ label: "Streaming Enabled", text: String(!!params.stream) }]; +}; diff --git a/integrations/replicate/tsconfig.json b/integrations/replicate/tsconfig.json new file mode 100644 index 0000000000..36ae307e42 --- /dev/null +++ b/integrations/replicate/tsconfig.json @@ -0,0 +1,4 @@ +{ + "extends": "@trigger.dev/tsconfig/integration.json", + "include": ["./src/**/*.ts", "tsup.config.ts"], +} diff --git a/integrations/replicate/tsup.config.ts b/integrations/replicate/tsup.config.ts new file mode 100644 index 0000000000..483aba1d59 --- /dev/null +++ b/integrations/replicate/tsup.config.ts @@ -0,0 +1,22 @@ +import { defineConfig } from "tsup"; + +export default defineConfig([ + { + name: "main", + entry: ["./src/index.ts"], + outDir: "./dist", + platform: "node", + format: ["cjs"], + legacyOutput: true, + sourcemap: true, + clean: true, + bundle: true, + splitting: false, + dts: true, + treeshake: { + preset: "smallest", + }, + esbuildPlugins: [], + external: ["http", "https", "util", "events", "tty", "os", "timers"], + }, +]); From a196a86a5b830820f02d8cd8cc49d215a4f6388c Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 09:48:52 +0000 Subject: [PATCH 04/37] Basic job catalog example --- references/job-catalog/package.json | 4 ++- references/job-catalog/src/replicate.ts | 39 +++++++++++++++++++++++++ references/job-catalog/tsconfig.json | 6 ++++ 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 references/job-catalog/src/replicate.ts diff --git a/references/job-catalog/package.json b/references/job-catalog/package.json index 5f0384ea91..fbdf01f2e1 100644 --- a/references/job-catalog/package.json +++ b/references/job-catalog/package.json @@ -24,6 +24,7 @@ "linear": "nodemon --watch src/linear.ts -r tsconfig-paths/register -r dotenv/config src/linear.ts", "status": "nodemon --watch src/status.ts -r tsconfig-paths/register -r dotenv/config src/status.ts", "byo-auth": "nodemon --watch src/byo-auth.ts -r tsconfig-paths/register -r dotenv/config src/byo-auth.ts", + "replicate": "nodemon --watch src/replicate.ts -r tsconfig-paths/register -r dotenv/config src/replicate.ts", "dev:trigger": "trigger-cli dev --port 8080" }, "dependencies": { @@ -43,7 +44,8 @@ "@types/node": "20.4.2", "typescript": "5.1.6", "zod": "3.21.4", - "@trigger.dev/linear": "workspace:*" + "@trigger.dev/linear": "workspace:*", + "@trigger.dev/replicate": "workspace:*" }, "trigger.dev": { "endpointId": "job-catalog" diff --git a/references/job-catalog/src/replicate.ts b/references/job-catalog/src/replicate.ts new file mode 100644 index 0000000000..be3194b970 --- /dev/null +++ b/references/job-catalog/src/replicate.ts @@ -0,0 +1,39 @@ +import { createExpressServer } from "@trigger.dev/express"; +import { TriggerClient, eventTrigger } from "@trigger.dev/sdk"; +import { Replicate } from "@trigger.dev/replicate"; +import { z } from "zod"; + +export const client = new TriggerClient({ + id: "job-catalog", + apiKey: process.env["TRIGGER_API_KEY"], + apiUrl: process.env["TRIGGER_API_URL"], + verbose: false, + ioLogLocalEnabled: true, +}); + +const replicate = new Replicate({ + id: "replicate", + apiKey: process.env["REPLICATE_API_KEY"]!, +}); + +client.defineJob({ + id: "replicate-create-prediction", + name: "Replicate - Create Prediction", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.predict", + schema: z.object({ + prompt: z.string(), + version: z.string(), + }), + }), + run: async (payload, io, ctx) => { + return io.replicate.predictions.createAndWaitForCompletion("await-prediction", { + version: payload.version, + input: { prompt: payload.prompt }, + }); + }, +}); + +createExpressServer(client); diff --git a/references/job-catalog/tsconfig.json b/references/job-catalog/tsconfig.json index 6ec3167e67..80823d1a8e 100644 --- a/references/job-catalog/tsconfig.json +++ b/references/job-catalog/tsconfig.json @@ -97,6 +97,12 @@ ], "@trigger.dev/linear/*": [ "../../integrations/linear/src/*" + ], + "@trigger.dev/replicate": [ + "../../integrations/replicate/src/index" + ], + "@trigger.dev/replicate/*": [ + "../../integrations/replicate/src/*" ] } } From 0de1147bffb3ed98a2ddf13c817b977d8d978c80 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 09:49:36 +0000 Subject: [PATCH 05/37] Integration catalog entry --- .../externalApis/integrationCatalog.server.ts | 2 + .../externalApis/integrations/replicate.ts | 50 +++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 apps/webapp/app/services/externalApis/integrations/replicate.ts diff --git a/apps/webapp/app/services/externalApis/integrationCatalog.server.ts b/apps/webapp/app/services/externalApis/integrationCatalog.server.ts index b86c2e496c..3c13f154ac 100644 --- a/apps/webapp/app/services/externalApis/integrationCatalog.server.ts +++ b/apps/webapp/app/services/externalApis/integrationCatalog.server.ts @@ -3,6 +3,7 @@ import { github } from "./integrations/github"; import { linear } from "./integrations/linear"; import { openai } from "./integrations/openai"; import { plain } from "./integrations/plain"; +import { replicate } from "./integrations/replicate"; import { resend } from "./integrations/resend"; import { sendgrid } from "./integrations/sendgrid"; import { slack } from "./integrations/slack"; @@ -37,6 +38,7 @@ export const integrationCatalog = new IntegrationCatalog({ linear, openai, plain, + replicate, resend, slack, stripe, diff --git a/apps/webapp/app/services/externalApis/integrations/replicate.ts b/apps/webapp/app/services/externalApis/integrations/replicate.ts new file mode 100644 index 0000000000..d8206f4506 --- /dev/null +++ b/apps/webapp/app/services/externalApis/integrations/replicate.ts @@ -0,0 +1,50 @@ +import type { HelpSample, Integration } from "../types"; + +function usageSample(hasApiKey: boolean): HelpSample { + const apiKeyPropertyName = "apiKey"; + + return { + title: "Using the client", + code: ` +import { Replicate } from "@trigger.dev/replicate"; + +const replicate = new Replicate({ + id: "__SLUG__",${hasApiKey ? `,\n ${apiKeyPropertyName}: process.env.REPLICATE_API_KEY!` : ""} +}); + +client.defineJob({ + id: "replicate-create-prediction", + name: "Replicate - Create Prediction", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.predict", + schema: z.object({ + prompt: z.string(), + version: z.string(), + }), + }), + run: async (payload, io, ctx) => { + return io.replicate.predictions.createAndWaitForCompletion("await-prediction", { + version: payload.version, + input: { prompt: payload.prompt }, + }); + }, +}); + `, + }; +} + +export const replicate: Integration = { + identifier: "replicate", + name: "Replicate", + packageName: "@trigger.dev/replicate@latest", + authenticationMethods: { + apikey: { + type: "apikey", + help: { + samples: [usageSample(true)], + }, + }, + }, +}; From 6f85972de7a70b4001c3d96bfb62d4328d6d351b Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:21:21 +0000 Subject: [PATCH 06/37] Check for callbackUrl during executeTask --- packages/trigger-sdk/src/io.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/trigger-sdk/src/io.ts b/packages/trigger-sdk/src/io.ts index 58c8ff4338..f551a3c305 100644 --- a/packages/trigger-sdk/src/io.ts +++ b/packages/trigger-sdk/src/io.ts @@ -604,7 +604,9 @@ export class IO { }); // TODO: empty return? maybe don't even parse first - if (task.status === "WAITING") return output; + if (task.status === "WAITING" && task.callbackUrl) { + return output; + } const completedTask = await this._apiClient.completeTask(this._id, task.id, { output: output ?? undefined, From 0ab9dc1c6e2e9ae950428752d077644e0f145364 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:28:23 +0000 Subject: [PATCH 07/37] Fix getAll --- integrations/replicate/src/index.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index 595e4bdcbf..c68b89fe26 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -133,13 +133,13 @@ export class Replicate implements TriggerIntegration { task: (key: string) => Promise>, key: IntegrationTaskKey ): ReplicateReturnType { - let results: T[] = []; + const allResults: T[] = []; for await (const results of this.paginate(task, key)) { - results.concat(results); + allResults.push(...results); } - return results; + return allResults; } request( From 3e9e73ef7e1e2006858e3ac3b357807a3667e7f5 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:42:45 +0000 Subject: [PATCH 08/37] Improve JSDoc --- packages/core/src/schemas/api.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/core/src/schemas/api.ts b/packages/core/src/schemas/api.ts index 127f95eace..61e80523fa 100644 --- a/packages/core/src/schemas/api.ts +++ b/packages/core/src/schemas/api.ts @@ -608,11 +608,11 @@ export const RunTaskOptionsSchema = z.object({ params: z.any(), /** The style of the log entry. */ style: StyleSchema.optional(), - /** Allows you to return the data sent to task.callbackUrl instead of the async callback return */ + /** Allows you to expose a `task.callbackUrl` to use in your tasks. Enabling this feature will cause the task to return the data sent to the callbackUrl instead of the usual async callback result. */ callback: z.object({ - /** Enable the callback feature */ + /** Causes the task to wait for and return the data of the first request sent to `task.callbackUrl`. */ enabled: z.boolean(), - /** Time to wait for callback requests */ + /** Time to wait for the first request to `task.callbackUrl`. Default: One hour. */ timeoutInSeconds: z.number(), }).partial().optional(), /** Allows you to link the Integration connection in the logs. This is handled automatically in integrations. */ From cbe31708d2ab976c92057c3f9723168d61ab923a Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:32:30 +0000 Subject: [PATCH 09/37] Bump version --- integrations/replicate/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/replicate/package.json b/integrations/replicate/package.json index ac7f8aecf3..e66eb37f18 100644 --- a/integrations/replicate/package.json +++ b/integrations/replicate/package.json @@ -1,6 +1,6 @@ { "name": "@trigger.dev/replicate", - "version": "2.1.4", + "version": "2.1.5", "description": "Trigger.dev integration for replicate", "main": "./dist/index.js", "types": "./dist/index.d.ts", From 47ec5e6554ee1b87febf481aae67482ba7880688 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 26 Sep 2023 13:19:13 +0000 Subject: [PATCH 10/37] Remove named queue --- apps/webapp/app/services/worker.server.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/webapp/app/services/worker.server.ts b/apps/webapp/app/services/worker.server.ts index 9cd8a73f36..114f758f34 100644 --- a/apps/webapp/app/services/worker.server.ts +++ b/apps/webapp/app/services/worker.server.ts @@ -245,7 +245,6 @@ function getWorkerQueue() { }, processCallbackTimeout: { priority: 0, // smaller number = higher priority - queueName: (payload) => `tasks:${payload.id}`, maxAttempts: 3, handler: async (payload, job) => { const service = new ProcessCallbackTimeoutService(); From 4749e10c58bc55bcbaf796d10cd07db25409ed44 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 09:07:39 +0000 Subject: [PATCH 11/37] Simplify runTask types --- docs/integrations/create-tasks.mdx | 2 +- integrations/airtable/src/index.ts | 2 +- integrations/github/src/index.ts | 2 +- integrations/linear/src/index.ts | 2 +- integrations/replicate/src/index.ts | 2 +- integrations/replicate/src/predictions.ts | 2 +- integrations/resend/src/index.ts | 2 +- integrations/sendgrid/src/index.ts | 2 +- integrations/slack/src/index.ts | 2 +- packages/core/src/schemas/api.ts | 25 ++++++++++++++--------- packages/trigger-sdk/src/io.ts | 6 +++--- 11 files changed, 27 insertions(+), 22 deletions(-) diff --git a/docs/integrations/create-tasks.mdx b/docs/integrations/create-tasks.mdx index fe5386a181..2ed3b09dfc 100644 --- a/docs/integrations/create-tasks.mdx +++ b/docs/integrations/create-tasks.mdx @@ -24,7 +24,7 @@ export class Github implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/airtable/src/index.ts b/integrations/airtable/src/index.ts index f8c2d99f15..b4d9d542cf 100644 --- a/integrations/airtable/src/index.ts +++ b/integrations/airtable/src/index.ts @@ -92,7 +92,7 @@ export class Airtable implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/github/src/index.ts b/integrations/github/src/index.ts index 59e2dc968f..0716e849b5 100644 --- a/integrations/github/src/index.ts +++ b/integrations/github/src/index.ts @@ -138,7 +138,7 @@ export class Github implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/linear/src/index.ts b/integrations/linear/src/index.ts index f6318ee7a3..473fefa044 100644 --- a/integrations/linear/src/index.ts +++ b/integrations/linear/src/index.ts @@ -158,7 +158,7 @@ export class Linear implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index c68b89fe26..8651eecd03 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -73,7 +73,7 @@ export class Replicate implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts index 59567d1759..8ff97fed5b 100644 --- a/integrations/replicate/src/predictions.ts +++ b/integrations/replicate/src/predictions.ts @@ -45,7 +45,7 @@ export class Predictions { Parameters[0], "webhook" | "webhook_events_filter" > - ): ReplicateReturnType { + ): ReplicateReturnType { return this.runTask( key, (client, task) => { diff --git a/integrations/resend/src/index.ts b/integrations/resend/src/index.ts index f93be1964d..b483fa48c9 100644 --- a/integrations/resend/src/index.ts +++ b/integrations/resend/src/index.ts @@ -100,7 +100,7 @@ export class Resend implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/sendgrid/src/index.ts b/integrations/sendgrid/src/index.ts index 49a1a55f2d..d08a39379b 100644 --- a/integrations/sendgrid/src/index.ts +++ b/integrations/sendgrid/src/index.ts @@ -70,7 +70,7 @@ export class SendGrid implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/slack/src/index.ts b/integrations/slack/src/index.ts index 9a8571d4ce..4934a99d1c 100644 --- a/integrations/slack/src/index.ts +++ b/integrations/slack/src/index.ts @@ -92,7 +92,7 @@ export class Slack implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/packages/core/src/schemas/api.ts b/packages/core/src/schemas/api.ts index 61e80523fa..607dee5d50 100644 --- a/packages/core/src/schemas/api.ts +++ b/packages/core/src/schemas/api.ts @@ -609,12 +609,15 @@ export const RunTaskOptionsSchema = z.object({ /** The style of the log entry. */ style: StyleSchema.optional(), /** Allows you to expose a `task.callbackUrl` to use in your tasks. Enabling this feature will cause the task to return the data sent to the callbackUrl instead of the usual async callback result. */ - callback: z.object({ - /** Causes the task to wait for and return the data of the first request sent to `task.callbackUrl`. */ - enabled: z.boolean(), - /** Time to wait for the first request to `task.callbackUrl`. Default: One hour. */ - timeoutInSeconds: z.number(), - }).partial().optional(), + callback: z + .object({ + /** Causes the task to wait for and return the data of the first request sent to `task.callbackUrl`. */ + enabled: z.boolean(), + /** Time to wait for the first request to `task.callbackUrl`. Default: One hour. */ + timeoutInSeconds: z.number(), + }) + .partial() + .optional(), /** Allows you to link the Integration connection in the logs. This is handled automatically in integrations. */ connectionKey: z.string().optional(), /** An operation you want to perform on the Trigger.dev platform, current only "fetch" is supported. If you wish to `fetch` use [`io.backgroundFetch()`](https://trigger.dev/docs/sdk/io/backgroundfetch) instead. */ @@ -641,10 +644,12 @@ export type RunTaskBodyInput = z.infer; export const RunTaskBodyOutputSchema = RunTaskBodyInputSchema.extend({ params: DeserializedJsonSchema.optional().nullable(), - callback: z.object({ - enabled: z.boolean(), - timeoutInSeconds: z.number().default(3600), - }).optional(), + callback: z + .object({ + enabled: z.boolean(), + timeoutInSeconds: z.number().default(3600), + }) + .optional(), }); export type RunTaskBodyOutput = z.infer; diff --git a/packages/trigger-sdk/src/io.ts b/packages/trigger-sdk/src/io.ts index f551a3c305..fa5a8afef3 100644 --- a/packages/trigger-sdk/src/io.ts +++ b/packages/trigger-sdk/src/io.ts @@ -526,12 +526,12 @@ export class IO { * @param onError The callback that will be called when the Task fails. The callback receives the error, the Task and the IO as parameters. If you wish to retry then return an object with a `retryAt` property. * @returns A Promise that resolves with the returned value of the callback. */ - async runTask | void, TOptions extends RunTaskOptions = RunTaskOptions>( + async runTask | void>( key: string | any[], callback: (task: ServerTask, io: IO) => Promise, - options?: TOptions, + options?: RunTaskOptions, onError?: RunTaskErrorCallback - ): Promise { + ): Promise { const parentId = this._taskStorage.getStore()?.taskId; if (parentId) { From ded84db151831bd262519d19e02a251407e70b97 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 09:13:28 +0000 Subject: [PATCH 12/37] Trust the types --- .../app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts index 3925d830dd..7402c3a758 100644 --- a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts @@ -36,7 +36,7 @@ export async function action({ request, params }: ActionArgs) { const service = new CallbackRunTaskService(); try { - await service.call(runId, id, body.data, new URL(request.url).href); + await service.call(runId, id, body.data, request.url); return json({ success: true }); } catch (error) { From b0a1fdd9099aaf6c84d64085ebf63a695245d1fe Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 09:43:28 +0000 Subject: [PATCH 13/37] Fail tasks on timeout --- .../app/services/tasks/processCallbackTimeout.ts | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/apps/webapp/app/services/tasks/processCallbackTimeout.ts b/apps/webapp/app/services/tasks/processCallbackTimeout.ts index e4528935b9..948691990d 100644 --- a/apps/webapp/app/services/tasks/processCallbackTimeout.ts +++ b/apps/webapp/app/services/tasks/processCallbackTimeout.ts @@ -1,4 +1,4 @@ -import { RuntimeEnvironmentType, type Task } from "@trigger.dev/database"; +import { RuntimeEnvironmentType } from "@trigger.dev/database"; import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server"; import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server"; import { logger } from "../logger.server"; @@ -25,10 +25,10 @@ export class ProcessCallbackTimeoutService { logger.debug("ProcessCallbackTimeoutService.call", { task }); - return await this.#resumeTask(task, null); + return await this.#failTask(task, "Remote callback timeout - no requests received"); } - async #resumeTask(task: NonNullable, output: any) { + async #failTask(task: NonNullable, error: string) { await $transaction(this.#prismaClient, async (tx) => { await tx.taskAttempt.updateMany({ where: { @@ -36,16 +36,17 @@ export class ProcessCallbackTimeoutService { status: "PENDING", }, data: { - status: "COMPLETED", + status: "ERRORED", + error }, }); await tx.task.update({ where: { id: task.id }, data: { - status: "COMPLETED", + status: "ERRORED", completedAt: new Date(), - output: output ? output : undefined, + output: error, }, }); From a53debbb2ab28871ddad575e2a3895b352507387 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 10:04:03 +0000 Subject: [PATCH 14/37] Callback timeout as param --- integrations/replicate/src/deployments.ts | 6 +++--- integrations/replicate/src/predictions.ts | 6 +++--- integrations/replicate/src/trainings.ts | 6 +++--- integrations/replicate/src/utils.ts | 9 +++++++++ 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts index e6ebd68a58..1b1c14b6c0 100644 --- a/integrations/replicate/src/deployments.ts +++ b/integrations/replicate/src/deployments.ts @@ -2,7 +2,7 @@ import { IntegrationTaskKey } from "@trigger.dev/sdk"; import ReplicateClient, { Prediction } from "replicate"; import { ReplicateRunTask } from "./index"; -import { createDeploymentProperties } from "./utils"; +import { callbackProperties, createDeploymentProperties } from "./utils"; import { ReplicateReturnType } from "./types"; export class Deployments { @@ -46,7 +46,7 @@ class Predictions { } & Omit< Parameters[2], "webhook" | "webhook_events_filter" - > + > & { timeoutInSeconds?: number } ): ReplicateReturnType { return this.runTask( key, @@ -62,7 +62,7 @@ class Predictions { { name: "Create And Await Prediction With Deployment", params, - properties: createDeploymentProperties(params), + properties: [...createDeploymentProperties(params), ...callbackProperties(params)], callback: { enabled: true }, } ); diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts index 8ff97fed5b..212e84befa 100644 --- a/integrations/replicate/src/predictions.ts +++ b/integrations/replicate/src/predictions.ts @@ -3,7 +3,7 @@ import ReplicateClient, { Page, Prediction } from "replicate"; import { ReplicateRunTask } from "./index"; import { ReplicateReturnType } from "./types"; -import { createPredictionProperties } from "./utils"; +import { callbackProperties, createPredictionProperties } from "./utils"; export class Predictions { constructor(private runTask: ReplicateRunTask) {} @@ -44,7 +44,7 @@ export class Predictions { params: Omit< Parameters[0], "webhook" | "webhook_events_filter" - > + > & { timeoutInSeconds?: number } ): ReplicateReturnType { return this.runTask( key, @@ -58,7 +58,7 @@ export class Predictions { { name: "Create And Await Prediction", params, - properties: createPredictionProperties(params), + properties: [...createPredictionProperties(params), ...callbackProperties(params)], callback: { enabled: true }, } ); diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts index 21cc48c017..75f32b179e 100644 --- a/integrations/replicate/src/trainings.ts +++ b/integrations/replicate/src/trainings.ts @@ -3,7 +3,7 @@ import ReplicateClient, { Page, Training } from "replicate"; import { ReplicateRunTask } from "./index"; import { ReplicateReturnType } from "./types"; -import { modelProperties } from "./utils"; +import { callbackProperties, modelProperties } from "./utils"; export class Trainings { constructor(private runTask: ReplicateRunTask) {} @@ -54,7 +54,7 @@ export class Trainings { } & Omit< Parameters[3], "webhook" | "webhook_events_filter" - > + > & { timeoutInSeconds?: number } ): ReplicateReturnType { return this.runTask( key, @@ -70,7 +70,7 @@ export class Trainings { { name: "Create And Await Training", params, - properties: modelProperties(params), + properties: [...modelProperties(params), ...callbackProperties(params)], callback: { enabled: true }, } ); diff --git a/integrations/replicate/src/utils.ts b/integrations/replicate/src/utils.ts index 0f63fef902..ebdf9eccc8 100644 --- a/integrations/replicate/src/utils.ts +++ b/integrations/replicate/src/utils.ts @@ -45,3 +45,12 @@ export const modelProperties = ( export const streamingProperty = (params: { stream?: boolean }) => { return [{ label: "Streaming Enabled", text: String(!!params.stream) }]; }; + +export const callbackProperties = (params: { timeoutInSeconds?: number }) => { + return [ + { + label: "Callback Timeout", + text: params.timeoutInSeconds ? `${params.timeoutInSeconds}s` : "default", + }, + ]; +}; From d80b13afc0b70ce191d5db43e05aa60c55ec033b Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 10:29:37 +0000 Subject: [PATCH 15/37] Mess with types --- integrations/replicate/src/deployments.ts | 2 +- integrations/replicate/src/index.ts | 12 +++++++++--- integrations/replicate/src/predictions.ts | 2 +- integrations/replicate/src/trainings.ts | 2 +- packages/core/src/schemas/api.ts | 2 ++ packages/trigger-sdk/src/io.ts | 6 ++++-- packages/trigger-sdk/src/types.ts | 2 ++ 7 files changed, 20 insertions(+), 8 deletions(-) diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts index 1b1c14b6c0..a6ec577cda 100644 --- a/integrations/replicate/src/deployments.ts +++ b/integrations/replicate/src/deployments.ts @@ -55,7 +55,7 @@ class Predictions { return client.deployments.predictions.create(deployment_owner, deployment_name, { ...options, - webhook: task.callbackUrl ?? undefined, + webhook: task.callbackUrl, webhook_events_filter: ["completed"], }); }, diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index 8651eecd03..1ffbe62056 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -1,6 +1,7 @@ import { TriggerIntegration, RunTaskOptions, + RunTaskOptionsWithCallback, IO, IOTask, IntegrationTaskKey, @@ -8,6 +9,7 @@ import { Json, retry, ConnectionAuth, + IOTaskWithCallback, } from "@trigger.dev/sdk"; import ReplicateClient, { Page } from "replicate"; @@ -64,10 +66,14 @@ export class Replicate implements TriggerIntegration { }); } - runTask | void>( + runTask | void, TOptions extends RunTaskOptions>( key: IntegrationTaskKey, - callback: (client: ReplicateClient, task: IOTask, io: IO) => Promise, - options?: RunTaskOptions, + callback: ( + client: ReplicateClient, + task: TOptions extends RunTaskOptionsWithCallback ? IOTaskWithCallback : IOTask, + io: IO + ) => Promise, + options?: TOptions, errorCallback?: RunTaskErrorCallback ): Promise { if (!this._io) throw new Error("No IO"); diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts index 212e84befa..8a82bfc183 100644 --- a/integrations/replicate/src/predictions.ts +++ b/integrations/replicate/src/predictions.ts @@ -51,7 +51,7 @@ export class Predictions { (client, task) => { return client.predictions.create({ ...params, - webhook: task.callbackUrl ?? undefined, + webhook: task.callbackUrl, webhook_events_filter: ["completed"], }); }, diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts index 75f32b179e..98fd69234e 100644 --- a/integrations/replicate/src/trainings.ts +++ b/integrations/replicate/src/trainings.ts @@ -63,7 +63,7 @@ export class Trainings { return client.trainings.create(model_owner, model_name, version_id, { ...options, - webhook: task.callbackUrl ?? undefined, + webhook: task.callbackUrl, webhook_events_filter: ["completed"], }); }, diff --git a/packages/core/src/schemas/api.ts b/packages/core/src/schemas/api.ts index 607dee5d50..b524a27481 100644 --- a/packages/core/src/schemas/api.ts +++ b/packages/core/src/schemas/api.ts @@ -630,6 +630,8 @@ export const RunTaskOptionsSchema = z.object({ export type RunTaskOptions = z.input; +export type RunTaskOptionsWithCallback = RunTaskOptions & { callback: { enabled: true } }; + export type OverridableRunTaskOptions = Pick< RunTaskOptions, "retry" | "delayUntil" | "description" diff --git a/packages/trigger-sdk/src/io.ts b/packages/trigger-sdk/src/io.ts index fa5a8afef3..8c9ea5d7c7 100644 --- a/packages/trigger-sdk/src/io.ts +++ b/packages/trigger-sdk/src/io.ts @@ -36,6 +36,8 @@ import { TriggerStatus } from "./status"; export type IOTask = ServerTask; +export type IOTaskWithCallback = IOTask & { callbackUrl: string }; + export type IOOptions = { id: string; apiClient: ApiClient; @@ -528,7 +530,7 @@ export class IO { */ async runTask | void>( key: string | any[], - callback: (task: ServerTask, io: IO) => Promise, + callback: (task: ServerTask & { callbackUrl: string }, io: IO) => Promise, options?: RunTaskOptions, onError?: RunTaskErrorCallback ): Promise { @@ -594,7 +596,7 @@ export class IO { const executeTask = async () => { try { - const result = await callback(task, this); + const result = await callback(task as ServerTask & { callbackUrl: string }, this); const output = SerializableJsonSchema.parse(result) as T; diff --git a/packages/trigger-sdk/src/types.ts b/packages/trigger-sdk/src/types.ts index e7938b236e..da5ad78f96 100644 --- a/packages/trigger-sdk/src/types.ts +++ b/packages/trigger-sdk/src/types.ts @@ -7,6 +7,7 @@ import type { RedactString, RegisteredOptionsDiff, RunTaskOptions, + RunTaskOptionsWithCallback, RuntimeEnvironmentType, SourceEventOption, TriggerMetadata, @@ -22,6 +23,7 @@ export type { RedactString, RegisteredOptionsDiff, RunTaskOptions, + RunTaskOptionsWithCallback, SourceEventOption, }; From 3f278e87a059bbbd53514f5d6fe12ad5f2e903b6 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 10:32:10 +0000 Subject: [PATCH 16/37] performRunExecutionV1 --- apps/webapp/app/services/runs/performRunExecutionV1.server.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/apps/webapp/app/services/runs/performRunExecutionV1.server.ts b/apps/webapp/app/services/runs/performRunExecutionV1.server.ts index 591f6f00b1..023b368695 100644 --- a/apps/webapp/app/services/runs/performRunExecutionV1.server.ts +++ b/apps/webapp/app/services/runs/performRunExecutionV1.server.ts @@ -409,7 +409,9 @@ export class PerformRunExecutionV1Service { // If the task has an operation, then the next performRunExecution will occur // when that operation has finished - if (!data.task.operation) { + // Tasks with callbacks enabled will also get processed separately, i.e. when + // they time out, or on valid requests to their callbackUrl + if (!data.task.operation && !data.task.callbackUrl) { const newJobExecution = await tx.jobRunExecution.create({ data: { runId: run.id, From 043ae42d6dc932b8cf23f5b72726e66a3529a7cf Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 11:00:02 +0000 Subject: [PATCH 17/37] Update runTask docs --- docs/sdk/io/runtask.mdx | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/sdk/io/runtask.mdx b/docs/sdk/io/runtask.mdx index 5791429980..2eb9cea54d 100644 --- a/docs/sdk/io/runtask.mdx +++ b/docs/sdk/io/runtask.mdx @@ -112,6 +112,22 @@ A Task is a resumable unit of a Run that can be retried, resumed and is logged. + + + An optional object that exposes settings for the remote callback feature. + + Enabling this feature will expose a `callbackUrl` property on the callback's Task parameter. Additionally, `io.runTask()` will now return a Promise that resolves with the body of the first request sent to that URL. + + + + Whether to enable the remote callback feature. + + + The value of the property. + + + + @@ -133,6 +149,8 @@ A Task is a resumable unit of a Run that can be retried, resumed and is logged. A Promise that resolves with the returned value of the callback. +If the remote callback feature `options.callback` is enabled, the Promise will instead resolve with the body of the first request sent to `task.callbackUrl`. + ```typescript Run a task From 259bd7ccbd9465104ad5715c288f8acc753a66c0 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Thu, 28 Sep 2023 14:11:12 +0000 Subject: [PATCH 18/37] Shorten callback task methods --- apps/webapp/app/services/externalApis/integrations/replicate.ts | 2 +- integrations/replicate/src/deployments.ts | 2 +- integrations/replicate/src/index.ts | 2 +- integrations/replicate/src/predictions.ts | 2 +- integrations/replicate/src/trainings.ts | 2 +- references/job-catalog/src/replicate.ts | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/apps/webapp/app/services/externalApis/integrations/replicate.ts b/apps/webapp/app/services/externalApis/integrations/replicate.ts index d8206f4506..74f20cdafd 100644 --- a/apps/webapp/app/services/externalApis/integrations/replicate.ts +++ b/apps/webapp/app/services/externalApis/integrations/replicate.ts @@ -25,7 +25,7 @@ client.defineJob({ }), }), run: async (payload, io, ctx) => { - return io.replicate.predictions.createAndWaitForCompletion("await-prediction", { + return io.replicate.predictions.createAndAwait("await-prediction", { version: payload.version, input: { prompt: payload.prompt }, }); diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts index a6ec577cda..e6f132f22c 100644 --- a/integrations/replicate/src/deployments.ts +++ b/integrations/replicate/src/deployments.ts @@ -38,7 +38,7 @@ class Predictions { ); } - createAndWaitForCompletion( + createAndAwait( key: IntegrationTaskKey, params: { deployment_owner: string; diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index 1ffbe62056..91a889903a 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -200,7 +200,7 @@ export class Replicate implements TriggerIntegration { const { version } = match.groups; - return this.predictions.createAndWaitForCompletion(key, { ...options, version }); + return this.predictions.createAndAwait(key, { ...options, version }); } // TODO: wait(prediction) - needs polling diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts index 8a82bfc183..44bdcc9eb2 100644 --- a/integrations/replicate/src/predictions.ts +++ b/integrations/replicate/src/predictions.ts @@ -39,7 +39,7 @@ export class Predictions { ); } - createAndWaitForCompletion( + createAndAwait( key: IntegrationTaskKey, params: Omit< Parameters[0], diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts index 98fd69234e..a1e0b3002d 100644 --- a/integrations/replicate/src/trainings.ts +++ b/integrations/replicate/src/trainings.ts @@ -45,7 +45,7 @@ export class Trainings { ); } - createAndWaitForCompletion( + createAndAwait( key: IntegrationTaskKey, params: { model_owner: string; diff --git a/references/job-catalog/src/replicate.ts b/references/job-catalog/src/replicate.ts index be3194b970..3cb21ec7d4 100644 --- a/references/job-catalog/src/replicate.ts +++ b/references/job-catalog/src/replicate.ts @@ -29,7 +29,7 @@ client.defineJob({ }), }), run: async (payload, io, ctx) => { - return io.replicate.predictions.createAndWaitForCompletion("await-prediction", { + return io.replicate.predictions.createAndAwait("await-prediction", { version: payload.version, input: { prompt: payload.prompt }, }); From 04ae90683b8637adb4503c4b59653641836749a2 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 07:20:19 +0000 Subject: [PATCH 19/37] Fix run method return type --- integrations/replicate/src/index.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index 91a889903a..7cbd9c2e6f 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -11,7 +11,7 @@ import { ConnectionAuth, IOTaskWithCallback, } from "@trigger.dev/sdk"; -import ReplicateClient, { Page } from "replicate"; +import ReplicateClient, { Page, Prediction } from "replicate"; import { Predictions } from "./predictions"; import { Models } from "./models"; @@ -182,7 +182,7 @@ export class Replicate implements TriggerIntegration { Parameters[1], "webhook" | "webhook_events_filter" | "wait" | "signal" > - ): ReplicateReturnType { + ): ReplicateReturnType { const { identifier, ...options } = params; // see: https://github.com/replicate/replicate-javascript/blob/4b0d9cb0e226fab3d3d31de5b32261485acf5626/index.js#L102 From 6f739cac6a3f176363608e7597107a0228770310 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 07:20:46 +0000 Subject: [PATCH 20/37] Image processing jobs --- references/job-catalog/src/replicate.ts | 95 +++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 7 deletions(-) diff --git a/references/job-catalog/src/replicate.ts b/references/job-catalog/src/replicate.ts index 3cb21ec7d4..7aca9396ca 100644 --- a/references/job-catalog/src/replicate.ts +++ b/references/job-catalog/src/replicate.ts @@ -17,22 +17,103 @@ const replicate = new Replicate({ }); client.defineJob({ - id: "replicate-create-prediction", - name: "Replicate - Create Prediction", + id: "replicate-forge-image", + name: "Replicate - Forge Image", version: "0.1.0", integrations: { replicate }, trigger: eventTrigger({ - name: "replicate.predict", + name: "replicate.bad.forgery", schema: z.object({ - prompt: z.string(), - version: z.string(), + imageUrl: z + .string() + .url() + .default("https://trigger.dev/blog/supabase-integration/postgres-meme.png"), }), }), run: async (payload, io, ctx) => { - return io.replicate.predictions.createAndAwait("await-prediction", { + const blipVersion = "2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746"; + const sdVersion = "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4"; + + const blipPrediction = await io.replicate.run("caption-image", { + identifier: `salesforce/blip:${blipVersion}`, + input: { + image: payload.imageUrl, + }, + }); + + if (typeof blipPrediction.output !== "string") { + throw new Error(`Expected string output, got ${typeof blipPrediction.output}`); + } + + const caption = blipPrediction.output.replace("Caption: ", ""); + + const sdPrediction = await io.replicate.predictions.createAndAwait("draw-image", { + version: sdVersion, + input: { + prompt: caption, + }, + }); + + return { + caption, + output: sdPrediction.output, + }; + }, +}); + +client.defineJob({ + id: "replicate-cinematic-prompt", + name: "Replicate - Cinematic Prompt", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.cinematic", + schema: z.object({ + prompt: z.string().default("rick astley riding a harley through post-apocalyptic miami"), + version: z + .string() + .default("af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33"), + }), + }), + run: async (payload, io, ctx) => { + const prediction = await io.replicate.predictions.createAndAwait("await-prediction", { version: payload.version, - input: { prompt: payload.prompt }, + input: { + prompt: `${payload.prompt}, cinematic, 70mm, anamorphic, bokeh`, + width: 1280, + height: 720, + }, }); + return prediction.output; + }, +}); + +client.defineJob({ + id: "replicate-pagination", + name: "Replicate - Pagination", + version: "0.1.0", + integrations: { + replicate, + }, + trigger: eventTrigger({ + name: "replicate.paginate", + }), + run: async (payload, io, ctx) => { + // getAll - returns an array of all results (uses paginate internally) + const all = await io.replicate.getAll(io.replicate.predictions.list, "get-all"); + + // paginate - returns an async generator, useful to process one page at a time + for await (const predictions of io.replicate.paginate( + io.replicate.predictions.list, + "paginate-all" + )) { + await io.logger.info("stats", { + total: predictions.length, + versions: predictions.map((p) => p.version), + }); + } + + return { count: all.length }; }, }); From 49fdcd02499b1bc9538b6359d1a89cb4301bad49 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 07:41:11 +0000 Subject: [PATCH 21/37] Replicate docs --- docs/integrations/apis/replicate.mdx | 163 +++++++++++++++++++++++++++ docs/integrations/introduction.mdx | 23 ++-- docs/mint.json | 1 + 3 files changed, 176 insertions(+), 11 deletions(-) create mode 100644 docs/integrations/apis/replicate.mdx diff --git a/docs/integrations/apis/replicate.mdx b/docs/integrations/apis/replicate.mdx new file mode 100644 index 0000000000..ef4dd7958e --- /dev/null +++ b/docs/integrations/apis/replicate.mdx @@ -0,0 +1,163 @@ +--- +title: Replicate +description: "Run machine learning tasks easily at scale" +--- + + + +## Installation + +To get started with the Replicate integration on Trigger.dev, you need to install the `@trigger.dev/replicate` package. +You can do this using npm, pnpm, or yarn: + + + +```bash npm +npm install @trigger.dev/replicate@latest +``` + +```bash pnpm +pnpm add @trigger.dev/replicate@latest +``` + +```bash yarn +yarn add @trigger.dev/replicate@latest +``` + + + +## Authentication + +To use the Replicate API with Trigger.dev, you have to provide an API Key. + +### API Key + +You can create an API Key in your [Account Settings](https://replicate.com/account/api-tokens). + +```ts +import { Replicate } from "@trigger.dev/replicate"; + +//this will use the passed in API key (defined in your environment variables) +const replicate = new Replicate({ + id: "replicate", + apiKey: process.env["REPLICATE_API_KEY"], +}); +``` + +## Usage + +Include the Replicate integration in your Trigger.dev job. + +```ts +client.defineJob({ + id: "replicate-cinematic-prompt", + name: "Replicate - Cinematic Prompt", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.cinematic", + schema: z.object({ + prompt: z.string().default("rick astley riding a harley through post-apocalyptic miami"), + version: z + .string() + .default("af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33"), + }), + }), + run: async (payload, io, ctx) => { + //wait for prediction completion (uses remote callbacks internally) + const prediction = await io.replicate.predictions.createAndAwait("await-prediction", { + version: payload.version, + input: { + prompt: `${payload.prompt}, cinematic, 70mm, anamorphic, bokeh`, + width: 1280, + height: 720, + }, + }); + return prediction.output; + }, +}); +``` + +### Pagination + +You can paginate responses: + +- Using the `getAll` helper +- Using the `paginate` helper + +```ts +client.defineJob({ + id: "replicate-pagination", + name: "Replicate Pagination", + version: "0.1.0", + integrations: { + replicate, + }, + trigger: eventTrigger({ + name: "replicate.paginate", + }), + run: async (payload, io, ctx) => { + // getAll - returns an array of all results (uses paginate internally) + const all = await io.replicate.getAll(io.replicate.predictions.list, "get-all"); + + // paginate - returns an async generator, useful to process one page at a time + for await (const predictions of io.replicate.paginate( + io.replicate.predictions.list, + "paginate-all" + )) { + await io.logger.info("stats", { + total: predictions.length, + versions: predictions.map((p) => p.version), + }); + } + + return { count: all.length }; + }, +}); +``` + +## Tasks + +### Collections + +| Function Name | Description | +| ------------------ | ---------------------------------------------------------------------- | +| `collections.get` | Gets a collection. | +| `collections.list` | Returns the first page of all collections. Use with pagination helper. | + +### Models + +| Function Name | Description | +| ----------------- | ------------------------ | +| `models.get` | Gets a model. | +| `models.versions` | Gets a model version. | +| `models.versions` | Gets all model versions. | + +### Predictions + +| Function Name | Description | +| ---------------------------- | ---------------------------------------------------------------------- | +| `predictions.cancel` | Cancels a prediction. | +| `predictions.create` | Creates a prediction. | +| `predictions.createAndAwait` | Creates and waits for a prediction. | +| `predictions.get` | Gets a prediction. | +| `predictions.list` | Returns the first page of all predictions. Use with pagination helper. | + +### Trainings + +| Function Name | Description | +| -------------------------- | -------------------------------------------------------------------- | +| `trainings.cancel` | Cancels a training. | +| `trainings.create` | Creates a training. | +| `trainings.createAndAwait` | Creates and waits for a training. | +| `trainings.get` | Gets a training. | +| `trainings.list` | Returns the first page of all trainings. Use with pagination helper. | + +### Misc + +| Function Name | Description | +| ------------- | --------------------------------------------------- | +| `getAll` | Pagination helper that returns an array of results. | +| `paginate` | Pagination helper that returns an async generator. | +| `request` | Sends authenticated requests to the Replicate API. | +| `run` | Creates and waits for a prediction. | diff --git a/docs/integrations/introduction.mdx b/docs/integrations/introduction.mdx index 30be9c1e36..688463510c 100644 --- a/docs/integrations/introduction.mdx +++ b/docs/integrations/introduction.mdx @@ -30,14 +30,15 @@ description: "Integrations make it easy to authenticate and use APIs." Navigate the menu or select Integrations from the table below. -| API | Description | Webhooks | Tasks | -| --------------------------------------- | ---------------------------------------------------------------- | -------- | ----- | -| [GitHub](/integrations/apis/github) | Subscribe to webhooks and perform actions | ✅ | ✅ | -| [Linear](/integrations/apis/linear) | Streamline project and issue tracking | ✅ | ✅ | -| [OpenAI](/integrations/apis/openai) | Generate text and images. Including longer than 30s prompts | N/A | ✅ | -| [Plain](/integrations/apis/plain) | Perform customer support using Plain | 🕘 | ✅ | -| [Resend](/integrations/apis/resend) | Send emails using Resend | 🕘 | ✅ | -| [SendGrid](/integrations/apis/sendgrid) | Send emails using SendGrid | 🕘 | ✅ | -| [Slack](/integrations/apis/slack) | Send Slack messages | 🕘 | ✅ | -| [Supabase](/integrations/apis/supabase) | Interact with your projects and databases | ✅ | ✅ | -| [Typeform](/integrations/apis/typeform) | Interact with the Typeform API and get notified of new responses | ✅ | ✅ | +| API | Description | Webhooks | Tasks | +| ----------------------------------------- | ---------------------------------------------------------------- | -------- | ----- | +| [GitHub](/integrations/apis/github) | Subscribe to webhooks and perform actions | ✅ | ✅ | +| [Linear](/integrations/apis/linear) | Streamline project and issue tracking | ✅ | ✅ | +| [OpenAI](/integrations/apis/openai) | Generate text and images. Including longer than 30s prompts | N/A | ✅ | +| [Plain](/integrations/apis/plain) | Perform customer support using Plain | 🕘 | ✅ | +| [Replicate](/integrations/apis/replicate) | Run machine learning tasks easily at scale | N/A | ✅ | +| [Resend](/integrations/apis/resend) | Send emails using Resend | 🕘 | ✅ | +| [SendGrid](/integrations/apis/sendgrid) | Send emails using SendGrid | 🕘 | ✅ | +| [Slack](/integrations/apis/slack) | Send Slack messages | 🕘 | ✅ | +| [Supabase](/integrations/apis/supabase) | Interact with your projects and databases | ✅ | ✅ | +| [Typeform](/integrations/apis/typeform) | Interact with the Typeform API and get notified of new responses | ✅ | ✅ | diff --git a/docs/mint.json b/docs/mint.json index 6c041b32bc..af00f3e432 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -247,6 +247,7 @@ "integrations/apis/linear", "integrations/apis/openai", "integrations/apis/plain", + "integrations/apis/replicate", "integrations/apis/resend", "integrations/apis/sendgrid", "integrations/apis/slack", From 9054750d314f8ee870b71baa97a7b022de6629a4 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 07:53:09 +0000 Subject: [PATCH 22/37] Text output example --- references/job-catalog/src/replicate.ts | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/references/job-catalog/src/replicate.ts b/references/job-catalog/src/replicate.ts index 7aca9396ca..5cd416bf08 100644 --- a/references/job-catalog/src/replicate.ts +++ b/references/job-catalog/src/replicate.ts @@ -61,6 +61,32 @@ client.defineJob({ }, }); +client.defineJob({ + id: "replicate-python-answers", + name: "Replicate - Python Answers", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.serious.monty", + schema: z.object({ + prompt: z.string().default("why are apples not oranges?"), + }), + }), + run: async (payload, io, ctx) => { + const prediction = await io.replicate.run("await-prediction", { + identifier: + "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + input: { + prompt: payload.prompt, + system_prompt: "Answer like John Cleese. Don't be funny.", + max_new_tokens: 200, + }, + }); + + return Array.isArray(prediction.output) ? prediction.output.join("") : prediction.output; + }, +}); + client.defineJob({ id: "replicate-cinematic-prompt", name: "Replicate - Cinematic Prompt", From 3d625fe38128ed6c0c869507b3ef607e59905c0c Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 08:02:58 +0000 Subject: [PATCH 23/37] Changeset --- .changeset/fair-plums-grin.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 .changeset/fair-plums-grin.md diff --git a/.changeset/fair-plums-grin.md b/.changeset/fair-plums-grin.md new file mode 100644 index 0000000000..31579ad8c2 --- /dev/null +++ b/.changeset/fair-plums-grin.md @@ -0,0 +1,13 @@ +--- +"@trigger.dev/replicate": patch +"@trigger.dev/airtable": patch +"@trigger.dev/sendgrid": patch +"@trigger.dev/sdk": patch +"@trigger.dev/github": patch +"@trigger.dev/linear": patch +"@trigger.dev/resend": patch +"@trigger.dev/slack": patch +"@trigger.dev/core": patch +--- + +First release of `@trigger.dev/replicate` integration with remote callback support. From f0f438cd369180a0e15d0c3d01ac9b0f64d65e3a Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 08:03:52 +0000 Subject: [PATCH 24/37] Version bump --- integrations/replicate/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/replicate/package.json b/integrations/replicate/package.json index e66eb37f18..65327cea14 100644 --- a/integrations/replicate/package.json +++ b/integrations/replicate/package.json @@ -1,6 +1,6 @@ { "name": "@trigger.dev/replicate", - "version": "2.1.5", + "version": "2.1.6", "description": "Trigger.dev integration for replicate", "main": "./dist/index.js", "types": "./dist/index.d.ts", From a87ebc975111b53f3616fbaf99204ecbe547ee0f Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 08:50:49 +0000 Subject: [PATCH 25/37] Roll back ugly types --- integrations/replicate/src/deployments.ts | 2 +- integrations/replicate/src/index.ts | 10 +++------- integrations/replicate/src/predictions.ts | 2 +- integrations/replicate/src/trainings.ts | 2 +- packages/core/src/schemas/api.ts | 2 -- packages/trigger-sdk/src/io.ts | 6 ++---- 6 files changed, 8 insertions(+), 16 deletions(-) diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts index e6f132f22c..9f4f18d55e 100644 --- a/integrations/replicate/src/deployments.ts +++ b/integrations/replicate/src/deployments.ts @@ -55,7 +55,7 @@ class Predictions { return client.deployments.predictions.create(deployment_owner, deployment_name, { ...options, - webhook: task.callbackUrl, + webhook: task.callbackUrl ?? "", webhook_events_filter: ["completed"], }); }, diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index 7cbd9c2e6f..b22ac52c55 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -66,14 +66,10 @@ export class Replicate implements TriggerIntegration { }); } - runTask | void, TOptions extends RunTaskOptions>( + runTask | void>( key: IntegrationTaskKey, - callback: ( - client: ReplicateClient, - task: TOptions extends RunTaskOptionsWithCallback ? IOTaskWithCallback : IOTask, - io: IO - ) => Promise, - options?: TOptions, + callback: (client: ReplicateClient, task: IOTask, io: IO) => Promise, + options?: RunTaskOptions, errorCallback?: RunTaskErrorCallback ): Promise { if (!this._io) throw new Error("No IO"); diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts index 44bdcc9eb2..d231bb47ca 100644 --- a/integrations/replicate/src/predictions.ts +++ b/integrations/replicate/src/predictions.ts @@ -51,7 +51,7 @@ export class Predictions { (client, task) => { return client.predictions.create({ ...params, - webhook: task.callbackUrl, + webhook: task.callbackUrl ?? "", webhook_events_filter: ["completed"], }); }, diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts index a1e0b3002d..fed2f3c6b8 100644 --- a/integrations/replicate/src/trainings.ts +++ b/integrations/replicate/src/trainings.ts @@ -63,7 +63,7 @@ export class Trainings { return client.trainings.create(model_owner, model_name, version_id, { ...options, - webhook: task.callbackUrl, + webhook: task.callbackUrl ?? "", webhook_events_filter: ["completed"], }); }, diff --git a/packages/core/src/schemas/api.ts b/packages/core/src/schemas/api.ts index b524a27481..607dee5d50 100644 --- a/packages/core/src/schemas/api.ts +++ b/packages/core/src/schemas/api.ts @@ -630,8 +630,6 @@ export const RunTaskOptionsSchema = z.object({ export type RunTaskOptions = z.input; -export type RunTaskOptionsWithCallback = RunTaskOptions & { callback: { enabled: true } }; - export type OverridableRunTaskOptions = Pick< RunTaskOptions, "retry" | "delayUntil" | "description" diff --git a/packages/trigger-sdk/src/io.ts b/packages/trigger-sdk/src/io.ts index 8c9ea5d7c7..fa5a8afef3 100644 --- a/packages/trigger-sdk/src/io.ts +++ b/packages/trigger-sdk/src/io.ts @@ -36,8 +36,6 @@ import { TriggerStatus } from "./status"; export type IOTask = ServerTask; -export type IOTaskWithCallback = IOTask & { callbackUrl: string }; - export type IOOptions = { id: string; apiClient: ApiClient; @@ -530,7 +528,7 @@ export class IO { */ async runTask | void>( key: string | any[], - callback: (task: ServerTask & { callbackUrl: string }, io: IO) => Promise, + callback: (task: ServerTask, io: IO) => Promise, options?: RunTaskOptions, onError?: RunTaskErrorCallback ): Promise { @@ -596,7 +594,7 @@ export class IO { const executeTask = async () => { try { - const result = await callback(task as ServerTask & { callbackUrl: string }, this); + const result = await callback(task, this); const output = SerializableJsonSchema.parse(result) as T; From 57ddd812c71c0b470b48a98356b8514ca648be19 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 09:42:06 +0000 Subject: [PATCH 26/37] Remove missing types --- integrations/replicate/src/index.ts | 2 -- packages/trigger-sdk/src/types.ts | 2 -- 2 files changed, 4 deletions(-) diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index b22ac52c55..1c7220f10b 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -1,7 +1,6 @@ import { TriggerIntegration, RunTaskOptions, - RunTaskOptionsWithCallback, IO, IOTask, IntegrationTaskKey, @@ -9,7 +8,6 @@ import { Json, retry, ConnectionAuth, - IOTaskWithCallback, } from "@trigger.dev/sdk"; import ReplicateClient, { Page, Prediction } from "replicate"; diff --git a/packages/trigger-sdk/src/types.ts b/packages/trigger-sdk/src/types.ts index da5ad78f96..e7938b236e 100644 --- a/packages/trigger-sdk/src/types.ts +++ b/packages/trigger-sdk/src/types.ts @@ -7,7 +7,6 @@ import type { RedactString, RegisteredOptionsDiff, RunTaskOptions, - RunTaskOptionsWithCallback, RuntimeEnvironmentType, SourceEventOption, TriggerMetadata, @@ -23,7 +22,6 @@ export type { RedactString, RegisteredOptionsDiff, RunTaskOptions, - RunTaskOptionsWithCallback, SourceEventOption, }; From 918d965dd7a97ef55d0377f379b34256dcf05805 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 10:12:43 +0000 Subject: [PATCH 27/37] Quicker return when waiting on remote callback --- packages/trigger-sdk/src/io.ts | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/packages/trigger-sdk/src/io.ts b/packages/trigger-sdk/src/io.ts index fa5a8afef3..0135d1c977 100644 --- a/packages/trigger-sdk/src/io.ts +++ b/packages/trigger-sdk/src/io.ts @@ -596,6 +596,14 @@ export class IO { try { const result = await callback(task, this); + if (task.status === "WAITING" && task.callbackUrl) { + this._logger.debug("Waiting for remote callback", { + idempotencyKey, + task, + }); + return {} as T; + } + const output = SerializableJsonSchema.parse(result) as T; this._logger.debug("Completing using output", { @@ -603,11 +611,6 @@ export class IO { task, }); - // TODO: empty return? maybe don't even parse first - if (task.status === "WAITING" && task.callbackUrl) { - return output; - } - const completedTask = await this._apiClient.completeTask(this._id, task.id, { output: output ?? undefined, properties: task.outputProperties ?? undefined, From 756649adcddda43b96a6c12151b4b15e99b80929 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 13:58:14 +0000 Subject: [PATCH 28/37] Remote callback example --- docs/sdk/io/runtask.mdx | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/docs/sdk/io/runtask.mdx b/docs/sdk/io/runtask.mdx index 2eb9cea54d..4572a5b87e 100644 --- a/docs/sdk/io/runtask.mdx +++ b/docs/sdk/io/runtask.mdx @@ -119,7 +119,7 @@ A Task is a resumable unit of a Run that can be retried, resumed and is logged. Enabling this feature will expose a `callbackUrl` property on the callback's Task parameter. Additionally, `io.runTask()` will now return a Promise that resolves with the body of the first request sent to that URL. - + Whether to enable the remote callback feature. @@ -219,4 +219,43 @@ client.defineJob({ }); ``` +```typescript Remote callbacks +client.defineJob({ + id: "remote-callback-example", + name: "Remote Callback example", + version: "0.1.1", + trigger: eventTrigger({ name: "predict" }), + integrations: { replicate }, + run: async (payload, io, ctx) => { + //runTask + const prediction = await io.runTask( + "create-and-await-prediction", + async (client, task) => { + //create a prediction using the underlying Replicate Integration client + await client.predictions.create({ + ...payload, + webhook: task.callbackUrl ?? "", + webhook_events_filter: ["completed"], + }); + //the actual return value will be the data sent to callbackUrl + //cast to the exact data type you expect to receive or `any` if unsure + return {} as Prediction; + }, + { + name: "Create and await Prediction", + icon: "replicate", + //remote callback settings + callback: { + enabled: true, + timeoutInSeconds: 300, + }, + } + ); + + //log the prediction output + await io.logger.info(prediction.output); + }, +}); +``` + From ddf1a80aa35977db519dd2f74a8d3bb6c78cca3a Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Fri, 29 Sep 2023 18:36:53 +0000 Subject: [PATCH 29/37] Bump version --- integrations/replicate/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/replicate/package.json b/integrations/replicate/package.json index 65327cea14..9478a0a36e 100644 --- a/integrations/replicate/package.json +++ b/integrations/replicate/package.json @@ -1,6 +1,6 @@ { "name": "@trigger.dev/replicate", - "version": "2.1.6", + "version": "2.1.7", "description": "Trigger.dev integration for replicate", "main": "./dist/index.js", "types": "./dist/index.d.ts", From 162bfc40bdfc8d2d6b3b3843ce5f421af801399a Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Mon, 2 Oct 2023 07:35:33 +0000 Subject: [PATCH 30/37] Remove schema parsing --- ....runs.$runId.tasks.$id.callback.$secret.ts | 26 +++++-------------- packages/core/src/schemas/api.ts | 5 ---- 2 files changed, 6 insertions(+), 25 deletions(-) diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts index 7402c3a758..5c1425e2a0 100644 --- a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts @@ -1,7 +1,5 @@ import type { ActionArgs } from "@remix-run/server-runtime"; import { json } from "@remix-run/server-runtime"; -import type { CallbackTaskBodyOutput } from "@trigger.dev/core"; -import { CallbackTaskBodyInputSchema } from "@trigger.dev/core"; import { RuntimeEnvironmentType } from "@trigger.dev/database"; import { z } from "zod"; import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server"; @@ -22,21 +20,14 @@ export async function action({ request, params }: ActionArgs) { const { runId, id } = ParamsSchema.parse(params); - // Now parse the request body - const anyBody = await request.json(); - - // Allows any valid object - // TODO: maybe add proper schema parsing during io.runTask(), or even skip this step - const body = CallbackTaskBodyInputSchema.safeParse(anyBody); - - if (!body.success) { - return json({ error: "Invalid request body" }, { status: 400 }); - } + // Parse body as JSON (no schema parsing) + const body = await request.json(); const service = new CallbackRunTaskService(); try { - await service.call(runId, id, body.data, request.url); + // Complete task with request body as output + await service.call(runId, id, body, request.url); return json({ success: true }); } catch (error) { @@ -55,12 +46,7 @@ export class CallbackRunTaskService { this.#prismaClient = prismaClient; } - public async call( - runId: string, - id: string, - taskBody: CallbackTaskBodyOutput, - callbackUrl: string - ): Promise { + public async call(runId: string, id: string, taskBody: any, callbackUrl: string): Promise { const task = await findTask(prisma, id); if (!task) { @@ -89,7 +75,7 @@ export class CallbackRunTaskService { await this.#resumeTask(task, taskBody); } - async #resumeTask(task: NonNullable, output: Record) { + async #resumeTask(task: NonNullable, output: any) { await $transaction(this.#prismaClient, async (tx) => { await tx.taskAttempt.updateMany({ where: { diff --git a/packages/core/src/schemas/api.ts b/packages/core/src/schemas/api.ts index 607dee5d50..041126bcdc 100644 --- a/packages/core/src/schemas/api.ts +++ b/packages/core/src/schemas/api.ts @@ -654,11 +654,6 @@ export const RunTaskBodyOutputSchema = RunTaskBodyInputSchema.extend({ export type RunTaskBodyOutput = z.infer; -export const CallbackTaskBodyInputSchema = z.object({}).passthrough(); - -export type CallbackTaskBodyInput = Prettify>; -export type CallbackTaskBodyOutput = z.infer; - export const CompleteTaskBodyInputSchema = RunTaskBodyInputSchema.pick({ properties: true, description: true, From bf85d37af03f61e2d38b00a8a7c3627cc095feb6 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 3 Oct 2023 20:45:56 +0000 Subject: [PATCH 31/37] Only schedule positive callback timeout --- apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts index 1dfcf164e7..946510b024 100644 --- a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts @@ -229,6 +229,7 @@ export class RunTaskService { { tx, runAt: task.delayUntil ?? undefined } ); } else if (task.status === "WAITING" && callbackUrl && taskBody.callback) { + if (taskBody.callback.timeoutInSeconds > 0) { // We need to schedule the callback timeout await workerQueue.enqueue( "processCallbackTimeout", @@ -237,6 +238,7 @@ export class RunTaskService { }, { tx, runAt: new Date(Date.now() + taskBody.callback.timeoutInSeconds * 1000) } ); + } } return task; From cd3147acc724d27e2cf567f6ead48f11d3526623 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 3 Oct 2023 20:47:15 +0000 Subject: [PATCH 32/37] Decrease callback secret length --- .../app/routes/api.v1.runs.$runId.tasks.ts | 18 +++++++++--------- .../app/services/sources/utils.server.ts | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts index 946510b024..c0f80cd8bb 100644 --- a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts @@ -168,7 +168,7 @@ export class RunTaskService { const taskId = ulid(); const callbackUrl = callbackEnabled - ? `${env.APP_ORIGIN}/api/v1/runs/${runId}/tasks/${taskId}/callback/${generateSecret()}` + ? `${env.APP_ORIGIN}/api/v1/runs/${runId}/tasks/${taskId}/callback/${generateSecret(12)}` : undefined; const task = await tx.task.create({ @@ -230,14 +230,14 @@ export class RunTaskService { ); } else if (task.status === "WAITING" && callbackUrl && taskBody.callback) { if (taskBody.callback.timeoutInSeconds > 0) { - // We need to schedule the callback timeout - await workerQueue.enqueue( - "processCallbackTimeout", - { - id: task.id, - }, - { tx, runAt: new Date(Date.now() + taskBody.callback.timeoutInSeconds * 1000) } - ); + // We need to schedule the callback timeout + await workerQueue.enqueue( + "processCallbackTimeout", + { + id: task.id, + }, + { tx, runAt: new Date(Date.now() + taskBody.callback.timeoutInSeconds * 1000) } + ); } } diff --git a/apps/webapp/app/services/sources/utils.server.ts b/apps/webapp/app/services/sources/utils.server.ts index 127ca4b7ab..4c2bc7ae7b 100644 --- a/apps/webapp/app/services/sources/utils.server.ts +++ b/apps/webapp/app/services/sources/utils.server.ts @@ -1,5 +1,5 @@ import crypto from "node:crypto"; -export function generateSecret(): string { - return crypto.randomBytes(32).toString("hex"); +export function generateSecret(sizeInBytes = 32): string { + return crypto.randomBytes(sizeInBytes).toString("hex"); } From 340b8ad1668abb10d6a33a2b483bbb38fe269386 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 3 Oct 2023 21:37:05 +0000 Subject: [PATCH 33/37] Explicit default timeouts --- integrations/replicate/src/deployments.ts | 12 ++++++++---- integrations/replicate/src/index.ts | 4 ++-- integrations/replicate/src/predictions.ts | 12 ++++++++---- integrations/replicate/src/trainings.ts | 12 ++++++++---- integrations/replicate/src/types.ts | 2 ++ integrations/replicate/src/utils.ts | 6 ++++-- 6 files changed, 32 insertions(+), 16 deletions(-) diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts index 9f4f18d55e..af59baddfb 100644 --- a/integrations/replicate/src/deployments.ts +++ b/integrations/replicate/src/deployments.ts @@ -3,7 +3,7 @@ import ReplicateClient, { Prediction } from "replicate"; import { ReplicateRunTask } from "./index"; import { callbackProperties, createDeploymentProperties } from "./utils"; -import { ReplicateReturnType } from "./types"; +import { CallbackTimeout, ReplicateReturnType } from "./types"; export class Deployments { constructor(private runTask: ReplicateRunTask) {} @@ -46,7 +46,8 @@ class Predictions { } & Omit< Parameters[2], "webhook" | "webhook_events_filter" - > & { timeoutInSeconds?: number } + >, + options: CallbackTimeout = { timeoutInSeconds: 3600 } ): ReplicateReturnType { return this.runTask( key, @@ -62,8 +63,11 @@ class Predictions { { name: "Create And Await Prediction With Deployment", params, - properties: [...createDeploymentProperties(params), ...callbackProperties(params)], - callback: { enabled: true }, + properties: [...createDeploymentProperties(params), ...callbackProperties(options)], + callback: { + enabled: true, + timeoutInSeconds: options.timeoutInSeconds, + }, } ); } diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index 1c7220f10b..b514ebe0a1 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -177,7 +177,7 @@ export class Replicate implements TriggerIntegration { "webhook" | "webhook_events_filter" | "wait" | "signal" > ): ReplicateReturnType { - const { identifier, ...options } = params; + const { identifier, ...paramsWithoutIdentifier } = params; // see: https://github.com/replicate/replicate-javascript/blob/4b0d9cb0e226fab3d3d31de5b32261485acf5626/index.js#L102 @@ -194,7 +194,7 @@ export class Replicate implements TriggerIntegration { const { version } = match.groups; - return this.predictions.createAndAwait(key, { ...options, version }); + return this.predictions.createAndAwait(key, { ...paramsWithoutIdentifier, version }); } // TODO: wait(prediction) - needs polling diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts index d231bb47ca..e00a6c8d16 100644 --- a/integrations/replicate/src/predictions.ts +++ b/integrations/replicate/src/predictions.ts @@ -2,7 +2,7 @@ import { IntegrationTaskKey } from "@trigger.dev/sdk"; import ReplicateClient, { Page, Prediction } from "replicate"; import { ReplicateRunTask } from "./index"; -import { ReplicateReturnType } from "./types"; +import { CallbackTimeout, ReplicateReturnType } from "./types"; import { callbackProperties, createPredictionProperties } from "./utils"; export class Predictions { @@ -44,7 +44,8 @@ export class Predictions { params: Omit< Parameters[0], "webhook" | "webhook_events_filter" - > & { timeoutInSeconds?: number } + >, + options: CallbackTimeout = { timeoutInSeconds: 3600 } ): ReplicateReturnType { return this.runTask( key, @@ -58,8 +59,11 @@ export class Predictions { { name: "Create And Await Prediction", params, - properties: [...createPredictionProperties(params), ...callbackProperties(params)], - callback: { enabled: true }, + properties: [...createPredictionProperties(params), ...callbackProperties(options)], + callback: { + enabled: true, + timeoutInSeconds: options.timeoutInSeconds, + }, } ); } diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts index fed2f3c6b8..08cfa3204f 100644 --- a/integrations/replicate/src/trainings.ts +++ b/integrations/replicate/src/trainings.ts @@ -2,7 +2,7 @@ import { IntegrationTaskKey } from "@trigger.dev/sdk"; import ReplicateClient, { Page, Training } from "replicate"; import { ReplicateRunTask } from "./index"; -import { ReplicateReturnType } from "./types"; +import { CallbackTimeout, ReplicateReturnType } from "./types"; import { callbackProperties, modelProperties } from "./utils"; export class Trainings { @@ -54,7 +54,8 @@ export class Trainings { } & Omit< Parameters[3], "webhook" | "webhook_events_filter" - > & { timeoutInSeconds?: number } + >, + options: CallbackTimeout = { timeoutInSeconds: 3600 } ): ReplicateReturnType { return this.runTask( key, @@ -70,8 +71,11 @@ export class Trainings { { name: "Create And Await Training", params, - properties: [...modelProperties(params), ...callbackProperties(params)], - callback: { enabled: true }, + properties: [...modelProperties(params), ...callbackProperties(options)], + callback: { + enabled: true, + timeoutInSeconds: options.timeoutInSeconds, + }, } ); } diff --git a/integrations/replicate/src/types.ts b/integrations/replicate/src/types.ts index 9863160718..d8fafcd1dc 100644 --- a/integrations/replicate/src/types.ts +++ b/integrations/replicate/src/types.ts @@ -1 +1,3 @@ +export type CallbackTimeout = { timeoutInSeconds?: number }; + export type ReplicateReturnType = Promise; diff --git a/integrations/replicate/src/utils.ts b/integrations/replicate/src/utils.ts index ebdf9eccc8..0a510690fe 100644 --- a/integrations/replicate/src/utils.ts +++ b/integrations/replicate/src/utils.ts @@ -1,3 +1,5 @@ +import { CallbackTimeout } from "./types"; + export const createPredictionProperties = ( params: Partial<{ version: string; @@ -46,11 +48,11 @@ export const streamingProperty = (params: { stream?: boolean }) => { return [{ label: "Streaming Enabled", text: String(!!params.stream) }]; }; -export const callbackProperties = (params: { timeoutInSeconds?: number }) => { +export const callbackProperties = (options: CallbackTimeout) => { return [ { label: "Callback Timeout", - text: params.timeoutInSeconds ? `${params.timeoutInSeconds}s` : "default", + text: options.timeoutInSeconds ? `${options.timeoutInSeconds}s` : "default", }, ]; }; From 20e9ac8152c76acdd5ab427d7db0ee08bd4d0e47 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Tue, 3 Oct 2023 21:45:47 +0000 Subject: [PATCH 34/37] Import deployments tasks --- integrations/replicate/src/index.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index b514ebe0a1..e46dfc9eb8 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -16,6 +16,7 @@ import { Models } from "./models"; import { Trainings } from "./trainings"; import { Collections } from "./collections"; import { ReplicateReturnType } from "./types"; +import { Deployments } from "./deployments"; export type ReplicateIntegrationOptions = { id: string; @@ -93,6 +94,10 @@ export class Replicate implements TriggerIntegration { return new Collections(this.runTask.bind(this)); } + get deployments() { + return new Deployments(this.runTask.bind(this)); + } + get models() { return new Models(this.runTask.bind(this)); } From 58f55fa2cdffc2526b81d21da05a22da63a9f2d4 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Wed, 4 Oct 2023 06:43:14 +0000 Subject: [PATCH 35/37] JSDoc --- integrations/replicate/src/collections.ts | 2 ++ integrations/replicate/src/deployments.ts | 2 ++ integrations/replicate/src/index.ts | 4 ++++ integrations/replicate/src/models.ts | 3 +++ integrations/replicate/src/predictions.ts | 5 +++++ integrations/replicate/src/trainings.ts | 5 +++++ 6 files changed, 21 insertions(+) diff --git a/integrations/replicate/src/collections.ts b/integrations/replicate/src/collections.ts index 31a94ff557..b6f1c01001 100644 --- a/integrations/replicate/src/collections.ts +++ b/integrations/replicate/src/collections.ts @@ -7,6 +7,7 @@ import { ReplicateReturnType } from "./types"; export class Collections { constructor(private runTask: ReplicateRunTask) {} + /** Fetch a model collection. */ get(key: IntegrationTaskKey, params: { slug: string }): ReplicateReturnType { return this.runTask( key, @@ -21,6 +22,7 @@ export class Collections { ); } + /** Fetch a list of model collections. */ list(key: IntegrationTaskKey): ReplicateReturnType> { return this.runTask( key, diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts index af59baddfb..c5c1508d12 100644 --- a/integrations/replicate/src/deployments.ts +++ b/integrations/replicate/src/deployments.ts @@ -16,6 +16,7 @@ export class Deployments { class Predictions { constructor(private runTask: ReplicateRunTask) {} + /** Create a new prediction with a deployment. */ create( key: IntegrationTaskKey, params: { @@ -38,6 +39,7 @@ class Predictions { ); } + /** Create a new prediction with a deployment and await the result. */ createAndAwait( key: IntegrationTaskKey, params: { diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts index e46dfc9eb8..0093be164d 100644 --- a/integrations/replicate/src/index.ts +++ b/integrations/replicate/src/index.ts @@ -110,6 +110,7 @@ export class Replicate implements TriggerIntegration { return new Trainings(this.runTask.bind(this)); } + /** Paginate through a list of results. */ async *paginate( task: (key: string) => Promise>, key: IntegrationTaskKey, @@ -134,6 +135,7 @@ export class Replicate implements TriggerIntegration { } } + /** Auto-paginate and return all results. */ async getAll( task: (key: string) => Promise>, key: IntegrationTaskKey @@ -147,6 +149,7 @@ export class Replicate implements TriggerIntegration { return allResults; } + /** Make a request to the Replicate API. */ request( key: IntegrationTaskKey, params: { @@ -173,6 +176,7 @@ export class Replicate implements TriggerIntegration { ); } + /** Run a model and await the result. */ run( key: IntegrationTaskKey, params: { diff --git a/integrations/replicate/src/models.ts b/integrations/replicate/src/models.ts index 4436bf1c8d..d4b3a78ac5 100644 --- a/integrations/replicate/src/models.ts +++ b/integrations/replicate/src/models.ts @@ -8,6 +8,7 @@ import { ReplicateReturnType } from "./types"; export class Models { constructor(private runTask: ReplicateRunTask) {} + /** Get information about a model. */ get( key: IntegrationTaskKey, params: { @@ -36,6 +37,7 @@ export class Models { class Versions { constructor(private runTask: ReplicateRunTask) {} + /** Get a specific model version. */ get( key: IntegrationTaskKey, params: { @@ -57,6 +59,7 @@ class Versions { ); } + /** List model versions. */ list( key: IntegrationTaskKey, params: { diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts index e00a6c8d16..9f6c604fd2 100644 --- a/integrations/replicate/src/predictions.ts +++ b/integrations/replicate/src/predictions.ts @@ -8,6 +8,7 @@ import { callbackProperties, createPredictionProperties } from "./utils"; export class Predictions { constructor(private runTask: ReplicateRunTask) {} + /** Cancel a prediction. */ cancel(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { return this.runTask( key, @@ -22,6 +23,7 @@ export class Predictions { ); } + /** Create a new prediction. */ create( key: IntegrationTaskKey, params: Parameters[0] @@ -39,6 +41,7 @@ export class Predictions { ); } + /** Create a new prediction and await the result. */ createAndAwait( key: IntegrationTaskKey, params: Omit< @@ -68,6 +71,7 @@ export class Predictions { ); } + /** Fetch a prediction. */ get(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { return this.runTask( key, @@ -82,6 +86,7 @@ export class Predictions { ); } + /** List all predictions. */ list(key: IntegrationTaskKey): ReplicateReturnType> { return this.runTask( key, diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts index 08cfa3204f..10a3ae576b 100644 --- a/integrations/replicate/src/trainings.ts +++ b/integrations/replicate/src/trainings.ts @@ -8,6 +8,7 @@ import { callbackProperties, modelProperties } from "./utils"; export class Trainings { constructor(private runTask: ReplicateRunTask) {} + /** Cancel a training. */ cancel(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { return this.runTask( key, @@ -22,6 +23,7 @@ export class Trainings { ); } + /** Create a new training. */ create( key: IntegrationTaskKey, params: { @@ -45,6 +47,7 @@ export class Trainings { ); } + /** Create a new training and await the result. */ createAndAwait( key: IntegrationTaskKey, params: { @@ -80,6 +83,7 @@ export class Trainings { ); } + /** Fetch a training. */ get(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { return this.runTask( key, @@ -94,6 +98,7 @@ export class Trainings { ); } + /** List all trainings. */ list(key: IntegrationTaskKey): ReplicateReturnType> { return this.runTask( key, From e5dce845f0b626e5b55f3b3fdb974a21b0d2ca60 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Wed, 4 Oct 2023 06:46:53 +0000 Subject: [PATCH 36/37] Deployments docs --- docs/integrations/apis/replicate.mdx | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/integrations/apis/replicate.mdx b/docs/integrations/apis/replicate.mdx index ef4dd7958e..978bb51c19 100644 --- a/docs/integrations/apis/replicate.mdx +++ b/docs/integrations/apis/replicate.mdx @@ -125,6 +125,13 @@ client.defineJob({ | `collections.get` | Gets a collection. | | `collections.list` | Returns the first page of all collections. Use with pagination helper. | +### Deployments + +| Function Name | Description | +| ---------------------------------------- | --------------------------------------------------------- | +| `deployments.predictions.create` | Creates a new prediction with a deployment. | +| `deployments.predictions.createAndAwait` | Creates and waits for a new prediction with a deployment. | + ### Models | Function Name | Description | From c8868823d8e9c5f644197b8aa66b7fc4fff81ff3 Mon Sep 17 00:00:00 2001 From: nicktrn <55853254+nicktrn@users.noreply.github.com> Date: Wed, 4 Oct 2023 07:25:50 +0000 Subject: [PATCH 37/37] Fix runTask examples, mention wrappers --- docs/sdk/io/runtask.mdx | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/sdk/io/runtask.mdx b/docs/sdk/io/runtask.mdx index 4572a5b87e..46a8d829d9 100644 --- a/docs/sdk/io/runtask.mdx +++ b/docs/sdk/io/runtask.mdx @@ -6,6 +6,8 @@ description: "`io.runTask()` allows you to run a [Task](/documentation/concepts/ A Task is a resumable unit of a Run that can be retried, resumed and is logged. [Integrations](/integrations) use Tasks internally to perform their actions. +The wrappers at `io.integration.runTask()` expose the underlying Integration client as the first callback parameter (see examples on the right). They will have defaults set for options and `onError` handlers, but should otherwise be considered identical to raw `io.runTask()`. + ## Parameters @@ -168,11 +170,11 @@ client.defineJob({ }, run: async (payload, io, ctx) => { //runTask - const response = await io.runTask( + const response = await io.github.runTask( "create-card", - async () => { + async (client) => { //create a project card using the underlying GitHub Integration client - return io.github.client.rest.projects.createCard({ + return client.rest.projects.createCard({ column_id: 123, note: "test", }); @@ -228,7 +230,7 @@ client.defineJob({ integrations: { replicate }, run: async (payload, io, ctx) => { //runTask - const prediction = await io.runTask( + const prediction = await io.replicate.runTask( "create-and-await-prediction", async (client, task) => { //create a prediction using the underlying Replicate Integration client