Skip to content

Replicate integration and remote callbacks #507

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 45 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
93ae955
Support tasks with remote callbacks
nicktrn Sep 26, 2023
ff0863e
Add common integration tsconfig
nicktrn Sep 26, 2023
f998afc
Add Replicate integration
nicktrn Sep 26, 2023
a196a86
Basic job catalog example
nicktrn Sep 26, 2023
0de1147
Integration catalog entry
nicktrn Sep 26, 2023
6f85972
Check for callbackUrl during executeTask
nicktrn Sep 26, 2023
0ab9dc1
Fix getAll
nicktrn Sep 26, 2023
3e9e73e
Improve JSDoc
nicktrn Sep 26, 2023
f6e5eb1
Merge branch 'main' into integrations/replicate
nicktrn Sep 26, 2023
cbe3170
Bump version
nicktrn Sep 26, 2023
47ec5e6
Remove named queue
nicktrn Sep 26, 2023
4f7272a
Merge branch 'main' into integrations/replicate
nicktrn Sep 27, 2023
ba94516
Merge branch 'main' into integrations/replicate
nicktrn Sep 28, 2023
4749e10
Simplify runTask types
nicktrn Sep 28, 2023
ded84db
Trust the types
nicktrn Sep 28, 2023
b0a1fdd
Fail tasks on timeout
nicktrn Sep 28, 2023
a53debb
Callback timeout as param
nicktrn Sep 28, 2023
d80b13a
Mess with types
nicktrn Sep 28, 2023
3f278e8
performRunExecutionV1
nicktrn Sep 28, 2023
043ae42
Update runTask docs
nicktrn Sep 28, 2023
259bd7c
Shorten callback task methods
nicktrn Sep 28, 2023
04ae906
Fix run method return type
nicktrn Sep 29, 2023
6f739ca
Image processing jobs
nicktrn Sep 29, 2023
49fdcd0
Replicate docs
nicktrn Sep 29, 2023
9054750
Text output example
nicktrn Sep 29, 2023
3d625fe
Changeset
nicktrn Sep 29, 2023
f0f438c
Version bump
nicktrn Sep 29, 2023
a87ebc9
Roll back ugly types
nicktrn Sep 29, 2023
4d119bf
Merge branch 'main' into integrations/replicate
nicktrn Sep 29, 2023
57ddd81
Remove missing types
nicktrn Sep 29, 2023
918d965
Quicker return when waiting on remote callback
nicktrn Sep 29, 2023
756649a
Remote callback example
nicktrn Sep 29, 2023
4844e73
Merge branch 'main' into integrations/replicate
nicktrn Sep 29, 2023
ddf1a80
Bump version
nicktrn Sep 29, 2023
90b4f05
Merge branch 'main' into integrations/replicate
nicktrn Oct 2, 2023
162bfc4
Remove schema parsing
nicktrn Oct 2, 2023
5dd050e
Merge branch 'main' into integrations/replicate
ericallam Oct 3, 2023
bf85d37
Only schedule positive callback timeout
nicktrn Oct 3, 2023
cd3147a
Decrease callback secret length
nicktrn Oct 3, 2023
340b8ad
Explicit default timeouts
nicktrn Oct 3, 2023
20e9ac8
Import deployments tasks
nicktrn Oct 3, 2023
58f55fa
JSDoc
nicktrn Oct 4, 2023
e5dce84
Deployments docs
nicktrn Oct 4, 2023
c886882
Fix runTask examples, mention wrappers
nicktrn Oct 4, 2023
f5a68f1
Merge branch 'replicate-cleanup' into integrations/replicate
nicktrn Oct 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .changeset/fair-plums-grin.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
"@trigger.dev/replicate": patch
"@trigger.dev/airtable": patch
"@trigger.dev/sendgrid": patch
"@trigger.dev/sdk": patch
"@trigger.dev/github": patch
"@trigger.dev/linear": patch
"@trigger.dev/resend": patch
"@trigger.dev/slack": patch
"@trigger.dev/core": patch
---

First release of `@trigger.dev/replicate` integration with remote callback support.
1 change: 1 addition & 0 deletions apps/webapp/app/models/task.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export function taskWithAttemptsToServerTask(task: TaskWithAttempts): ServerTask
attempts: task.attempts.length,
idempotencyKey: task.idempotencyKey,
operation: task.operation,
callbackUrl: task.callbackUrl,
};
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import type { ActionArgs } from "@remix-run/server-runtime";
import { json } from "@remix-run/server-runtime";
import { RuntimeEnvironmentType } from "@trigger.dev/database";
import { z } from "zod";
import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server";
import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server";
import { logger } from "~/services/logger.server";

const ParamsSchema = z.object({
runId: z.string(),
id: z.string(),
secret: z.string(),
});

export async function action({ request, params }: ActionArgs) {
// Ensure this is a POST request
if (request.method.toUpperCase() !== "POST") {
return { status: 405, body: "Method Not Allowed" };
}

const { runId, id } = ParamsSchema.parse(params);

// Parse body as JSON (no schema parsing)
const body = await request.json();

const service = new CallbackRunTaskService();

try {
// Complete task with request body as output
await service.call(runId, id, body, request.url);

return json({ success: true });
} catch (error) {
if (error instanceof Error) {
logger.error("Error while processing task callback:", { error });
}

return json({ error: "Something went wrong" }, { status: 500 });
}
}

export class CallbackRunTaskService {
#prismaClient: PrismaClient;

constructor(prismaClient: PrismaClient = prisma) {
this.#prismaClient = prismaClient;
}

public async call(runId: string, id: string, taskBody: any, callbackUrl: string): Promise<void> {
const task = await findTask(prisma, id);

if (!task) {
return;
}

if (task.runId !== runId) {
return;
}

if (task.status !== "WAITING") {
return;
}

if (!task.callbackUrl) {
return;
}

if (new URL(task.callbackUrl).pathname !== new URL(callbackUrl).pathname) {
logger.error("Callback URLs don't match", { runId, taskId: id, callbackUrl });
return;
}

logger.debug("CallbackRunTaskService.call()", { task });

await this.#resumeTask(task, taskBody);
}

async #resumeTask(task: NonNullable<FoundTask>, output: any) {
await $transaction(this.#prismaClient, async (tx) => {
await tx.taskAttempt.updateMany({
where: {
taskId: task.id,
status: "PENDING",
},
data: {
status: "COMPLETED",
},
});

await tx.task.update({
where: { id: task.id },
data: {
status: "COMPLETED",
completedAt: new Date(),
output: output ? output : undefined,
},
});

await this.#resumeRunExecution(task, tx);
});
}

async #resumeRunExecution(task: NonNullable<FoundTask>, prisma: PrismaClientOrTransaction) {
await enqueueRunExecutionV2(task.run, prisma, {
skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT,
});
}
}

type FoundTask = Awaited<ReturnType<typeof findTask>>;

async function findTask(prisma: PrismaClientOrTransaction, id: string) {
return prisma.task.findUnique({
where: { id },
include: {
run: {
include: {
environment: true,
queue: true,
},
},
},
});
}
28 changes: 25 additions & 3 deletions apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import { authenticateApiRequest } from "~/services/apiAuth.server";
import { logger } from "~/services/logger.server";
import { ulid } from "~/services/ulid.server";
import { workerQueue } from "~/services/worker.server";
import { generateSecret } from "~/services/sources/utils.server";
import { env } from "~/env.server";

const ParamsSchema = z.object({
runId: z.string(),
Expand Down Expand Up @@ -185,10 +187,13 @@ export class RunTaskService {
},
});

const delayUntilInFuture = taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now();
const callbackEnabled = taskBody.callback?.enabled;

if (existingTask) {
if (existingTask.status === "CANCELED") {
const existingTaskStatus =
(taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger
delayUntilInFuture || callbackEnabled || taskBody.trigger
? "WAITING"
: taskBody.noop
? "COMPLETED"
Expand Down Expand Up @@ -233,16 +238,21 @@ export class RunTaskService {
status = "CANCELED";
} else {
status =
(taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger
delayUntilInFuture || callbackEnabled || taskBody.trigger
? "WAITING"
: taskBody.noop
? "COMPLETED"
: "RUNNING";
}

const taskId = ulid();
const callbackUrl = callbackEnabled
? `${env.APP_ORIGIN}/api/v1/runs/${runId}/tasks/${taskId}/callback/${generateSecret(12)}`
: undefined;

const task = await tx.task.create({
data: {
id: ulid(),
id: taskId,
idempotencyKey,
displayKey: taskBody.displayKey,
runConnection: taskBody.connectionKey
Expand Down Expand Up @@ -273,6 +283,7 @@ export class RunTaskService {
properties: taskBody.properties ?? undefined,
redact: taskBody.redact ?? undefined,
operation: taskBody.operation,
callbackUrl,
style: taskBody.style ?? { style: "normal" },
attempts: {
create: {
Expand All @@ -296,6 +307,17 @@ export class RunTaskService {
},
{ tx, runAt: task.delayUntil ?? undefined }
);
} else if (task.status === "WAITING" && callbackUrl && taskBody.callback) {
if (taskBody.callback.timeoutInSeconds > 0) {
// We need to schedule the callback timeout
await workerQueue.enqueue(
"processCallbackTimeout",
{
id: task.id,
},
{ tx, runAt: new Date(Date.now() + taskBody.callback.timeoutInSeconds * 1000) }
);
}
}

return task;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { github } from "./integrations/github";
import { linear } from "./integrations/linear";
import { openai } from "./integrations/openai";
import { plain } from "./integrations/plain";
import { replicate } from "./integrations/replicate";
import { resend } from "./integrations/resend";
import { sendgrid } from "./integrations/sendgrid";
import { slack } from "./integrations/slack";
Expand Down Expand Up @@ -37,6 +38,7 @@ export const integrationCatalog = new IntegrationCatalog({
linear,
openai,
plain,
replicate,
resend,
slack,
stripe,
Expand Down
50 changes: 50 additions & 0 deletions apps/webapp/app/services/externalApis/integrations/replicate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import type { HelpSample, Integration } from "../types";

function usageSample(hasApiKey: boolean): HelpSample {
const apiKeyPropertyName = "apiKey";

return {
title: "Using the client",
code: `
import { Replicate } from "@trigger.dev/replicate";

const replicate = new Replicate({
id: "__SLUG__",${hasApiKey ? `,\n ${apiKeyPropertyName}: process.env.REPLICATE_API_KEY!` : ""}
});

client.defineJob({
id: "replicate-create-prediction",
name: "Replicate - Create Prediction",
version: "0.1.0",
integrations: { replicate },
trigger: eventTrigger({
name: "replicate.predict",
schema: z.object({
prompt: z.string(),
version: z.string(),
}),
}),
run: async (payload, io, ctx) => {
return io.replicate.predictions.createAndAwait("await-prediction", {
version: payload.version,
input: { prompt: payload.prompt },
});
},
});
`,
};
}

export const replicate: Integration = {
identifier: "replicate",
name: "Replicate",
packageName: "@trigger.dev/replicate@latest",
authenticationMethods: {
apikey: {
type: "apikey",
help: {
samples: [usageSample(true)],
},
},
},
};
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ export class PerformRunExecutionV1Service {

// If the task has an operation, then the next performRunExecution will occur
// when that operation has finished
if (!data.task.operation) {
// Tasks with callbacks enabled will also get processed separately, i.e. when
// they time out, or on valid requests to their callbackUrl
if (!data.task.operation && !data.task.callbackUrl) {
const newJobExecution = await tx.jobRunExecution.create({
data: {
runId: run.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,9 @@ export class PerformRunExecutionV2Service {

// If the task has an operation, then the next performRunExecution will occur
// when that operation has finished
if (!data.task.operation) {
// Tasks with callbacks enabled will also get processed separately, i.e. when
// they time out, or on valid requests to their callbackUrl
if (!data.task.operation && !data.task.callbackUrl) {
await enqueueRunExecutionV2(run, tx, {
runAt: data.task.delayUntil ?? undefined,
resumeTaskId: data.task.id,
Expand Down
4 changes: 2 additions & 2 deletions apps/webapp/app/services/sources/utils.server.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import crypto from "node:crypto";

export function generateSecret(): string {
return crypto.randomBytes(32).toString("hex");
export function generateSecret(sizeInBytes = 32): string {
return crypto.randomBytes(sizeInBytes).toString("hex");
}
76 changes: 76 additions & 0 deletions apps/webapp/app/services/tasks/processCallbackTimeout.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { RuntimeEnvironmentType } from "@trigger.dev/database";
import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server";
import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server";
import { logger } from "../logger.server";

type FoundTask = Awaited<ReturnType<typeof findTask>>;

export class ProcessCallbackTimeoutService {
#prismaClient: PrismaClient;

constructor(prismaClient: PrismaClient = prisma) {
this.#prismaClient = prismaClient;
}

public async call(id: string) {
const task = await findTask(this.#prismaClient, id);

if (!task) {
return;
}

if (task.status !== "WAITING" || !task.callbackUrl) {
return;
}

logger.debug("ProcessCallbackTimeoutService.call", { task });

return await this.#failTask(task, "Remote callback timeout - no requests received");
}

async #failTask(task: NonNullable<FoundTask>, error: string) {
await $transaction(this.#prismaClient, async (tx) => {
await tx.taskAttempt.updateMany({
where: {
taskId: task.id,
status: "PENDING",
},
data: {
status: "ERRORED",
error
},
});

await tx.task.update({
where: { id: task.id },
data: {
status: "ERRORED",
completedAt: new Date(),
output: error,
},
});

await this.#resumeRunExecution(task, tx);
});
}

async #resumeRunExecution(task: NonNullable<FoundTask>, prisma: PrismaClientOrTransaction) {
await enqueueRunExecutionV2(task.run, prisma, {
skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT,
});
}
}

async function findTask(prisma: PrismaClient, id: string) {
return prisma.task.findUnique({
where: { id },
include: {
run: {
include: {
environment: true,
queue: true,
},
},
},
});
}
Loading