diff --git a/apps/opik-frontend/src/api/playground/createLogPlaygroundProcessor.ts b/apps/opik-frontend/src/api/playground/createLogPlaygroundProcessor.ts new file mode 100644 index 0000000000..5f48925eb5 --- /dev/null +++ b/apps/opik-frontend/src/api/playground/createLogPlaygroundProcessor.ts @@ -0,0 +1,213 @@ +import asyncLib from "async"; +import { v7 } from "uuid"; +import pick from "lodash/pick"; + +import { + LogExperiment, + LogExperimentItem, + LogSpan, + LogTrace, +} from "@/types/playground"; + +import { SPAN_TYPE } from "@/types/traces"; +import api, { + EXPERIMENTS_REST_ENDPOINT, + SPANS_REST_ENDPOINT, + TRACES_REST_ENDPOINT, +} from "@/api/api"; +import { snakeCaseObj } from "@/lib/utils"; +import { getModelProvider } from "@/lib/llm"; +import { createBatchProcessor } from "@/lib/batches"; +import { RunStreamingReturn } from "@/api/playground/useCompletionProxyStreaming"; +import { LLMPromptConfigsType, PROVIDER_MODEL_TYPE } from "@/types/providers"; +import { ProviderMessageType } from "@/types/llm"; + +export interface LogQueueParams extends RunStreamingReturn { + promptId: string; + datasetItemId?: string; + datasetName: string | null; + model: PROVIDER_MODEL_TYPE | ""; + providerMessages: ProviderMessageType[]; + configs: LLMPromptConfigsType; +} + +export interface LogProcessorArgs { + onAddExperimentRegistry: (loggedExperiments: LogExperiment[]) => void; + onError: (error: Error) => void; + onCreateTraces: (traces: LogTrace[]) => void; +} + +export interface LogProcessor { + log: (run: LogQueueParams) => void; +} + +const createBatchTraces = async (traces: LogTrace[]) => { + return api.post(`${TRACES_REST_ENDPOINT}batch`, { + traces: traces.map(snakeCaseObj), + }); +}; + +const createBatchSpans = async (spans: LogSpan[]) => { + return api.post(`${SPANS_REST_ENDPOINT}batch`, { + spans: spans.map(snakeCaseObj), + }); +}; + +const createExperiment = async (experiment: LogExperiment) => { + return api.post(EXPERIMENTS_REST_ENDPOINT, snakeCaseObj(experiment)); +}; + +const createBatchExperimentItems = async ( + experimentItems: LogExperimentItem[], +) => { + await api.post(`${EXPERIMENTS_REST_ENDPOINT}items`, { + experiment_items: experimentItems.map(snakeCaseObj), + }); +}; + +const PLAYGROUND_PROJECT_NAME = "playground"; +const PLAYGROUND_TRACE_SPAN_NAME = "chat_completion_create"; +const USAGE_FIELDS_TO_SEND = [ + "completion_tokens", + "prompt_tokens", + "total_tokens", +]; + +const getTraceFromRun = (run: LogQueueParams): LogTrace => { + return { + id: v7(), + projectName: PLAYGROUND_PROJECT_NAME, + name: PLAYGROUND_TRACE_SPAN_NAME, + startTime: run.startTime, + endTime: run.endTime, + input: { messages: run.providerMessages }, + output: { output: run.result || run.providerError }, + }; +}; + +const getSpanFromRun = (run: LogQueueParams, traceId: string): LogSpan => { + return { + id: v7(), + traceId, + projectName: PLAYGROUND_PROJECT_NAME, + type: SPAN_TYPE.llm, + name: PLAYGROUND_TRACE_SPAN_NAME, + startTime: run.startTime, + endTime: run.endTime, + input: { messages: run.providerMessages }, + output: { choices: run.choices ? run.choices : [] }, + usage: !run.usage ? undefined : pick(run.usage, USAGE_FIELDS_TO_SEND), + metadata: { + created_from: run.model ? getModelProvider(run.model) : "", + usage: run.usage, + model: run.model, + parameters: run.configs, + }, + }; +}; + +const getExperimentFromRun = (run: LogQueueParams): LogExperiment => { + return { + id: v7(), + datasetName: run.datasetName!, + metadata: { + model: run.model, + messages: JSON.stringify(run.providerMessages), + }, + }; +}; + +const getExperimentItemFromRun = ( + run: LogQueueParams, + experimentId: string, + traceId: string, +): LogExperimentItem => { + return { + id: v7(), + datasetItemId: run.datasetItemId!, + experimentId, + traceId, + }; +}; + +const CREATE_EXPERIMENT_CONCURRENCY_RATE = 5; + +const createLogPlaygroundProcessor = ({ + onAddExperimentRegistry, + onError, + onCreateTraces, +}: LogProcessorArgs): LogProcessor => { + const experimentPromptMap: Record = {}; + const experimentRegistry: LogExperiment[] = []; + + const spanBatch = createBatchProcessor(async (spans) => { + try { + await createBatchSpans(spans); + } catch { + onError(new Error("There has been an error with logging spans")); + } + }); + + const traceBatch = createBatchProcessor(async (traces) => { + try { + await createBatchTraces(traces); + onCreateTraces(traces); + } catch { + onError(new Error("There has been an error with logging traces")); + } + }); + + const experimentItemsBatch = createBatchProcessor( + async (experimentItems) => { + try { + await createBatchExperimentItems(experimentItems); + } catch { + onError( + new Error("There has been an error with logging experiment items"), + ); + } + }, + ); + + const experimentsQueue = asyncLib.queue(async (e) => { + try { + await createExperiment(e); + experimentRegistry.push(e); + } catch { + onError(new Error("There has been an error with logging experiments")); + } + }, CREATE_EXPERIMENT_CONCURRENCY_RATE); + + experimentsQueue.drain(() => { + onAddExperimentRegistry(experimentRegistry); + }); + + return { + log: (run: LogQueueParams) => { + const { promptId } = run; + + const trace = getTraceFromRun(run); + const span = getSpanFromRun(run, trace.id); + + // create a missing experiment + if (!experimentPromptMap[promptId]) { + const experiment = getExperimentFromRun(run); + experimentPromptMap[promptId] = experiment.id; + experimentsQueue.push(experiment); + } + + const experimentId = experimentPromptMap[promptId]; + const experimentItem = getExperimentItemFromRun( + run, + experimentId, + trace.id, + ); + + traceBatch.addItem(trace); + spanBatch.addItem(span); + experimentItemsBatch.addItem(experimentItem); + }, + }; +}; + +export default createLogPlaygroundProcessor; diff --git a/apps/opik-frontend/src/api/playground/useCreateOutputTraceAndSpan.ts b/apps/opik-frontend/src/api/playground/useCreateOutputTraceAndSpan.ts deleted file mode 100644 index 5079061982..0000000000 --- a/apps/opik-frontend/src/api/playground/useCreateOutputTraceAndSpan.ts +++ /dev/null @@ -1,94 +0,0 @@ -import { useCallback } from "react"; -import { v7 } from "uuid"; -import pick from "lodash/pick"; - -import useSpanCreateMutation from "@/api/traces/useSpanCreateMutation"; -import useTraceCreateMutation from "@/api/traces/useTraceCreateMutation"; -import { RunStreamingReturn } from "@/api/playground/useCompletionProxyStreaming"; - -import { ProviderMessageType } from "@/types/llm"; -import { useToast } from "@/components/ui/use-toast"; -import { SPAN_TYPE } from "@/types/traces"; -import { LLMPromptConfigsType, PROVIDER_MODEL_TYPE } from "@/types/providers"; - -const PLAYGROUND_TRACE_SPAN_NAME = "chat_completion_create"; - -const USAGE_FIELDS_TO_SEND = [ - "completion_tokens", - "prompt_tokens", - "total_tokens", -]; - -const PLAYGROUND_PROJECT_NAME = "playground"; - -export interface CreateTraceSpanParams extends RunStreamingReturn { - model: PROVIDER_MODEL_TYPE | ""; - providerMessages: ProviderMessageType[]; - configs: LLMPromptConfigsType; -} - -const useCreateOutputTraceAndSpan = () => { - const { toast } = useToast(); - - const { mutateAsync: createSpanMutateAsync } = useSpanCreateMutation(); - const { mutateAsync: createTraceMutateAsync } = useTraceCreateMutation(); - - const createTraceSpan = useCallback( - async ({ - startTime, - endTime, - result, - usage, - providerError, - choices, - model, - providerMessages, - configs, - }: CreateTraceSpanParams) => { - const traceId = v7(); - const spanId = v7(); - - try { - await createTraceMutateAsync({ - id: traceId, - projectName: PLAYGROUND_PROJECT_NAME, - name: PLAYGROUND_TRACE_SPAN_NAME, - startTime, - endTime, - input: { messages: providerMessages }, - output: { output: result || providerError }, - }); - - await createSpanMutateAsync({ - id: spanId, - traceId, - projectName: PLAYGROUND_PROJECT_NAME, - type: SPAN_TYPE.llm, - name: PLAYGROUND_TRACE_SPAN_NAME, - startTime, - endTime, - input: { messages: providerMessages }, - output: { choices }, - usage: !usage ? undefined : pick(usage, USAGE_FIELDS_TO_SEND), - metadata: { - created_from: "openai", - usage, - model, - parameters: configs, - }, - }); - } catch { - toast({ - title: "Error", - description: "There was an error while logging data", - variant: "destructive", - }); - } - }, - [createTraceMutateAsync, createSpanMutateAsync, toast], - ); - - return createTraceSpan; -}; - -export default useCreateOutputTraceAndSpan; diff --git a/apps/opik-frontend/src/api/traces/useSpanCreateMutation.ts b/apps/opik-frontend/src/api/traces/useSpanCreateMutation.ts deleted file mode 100644 index 8c1aa69a8d..0000000000 --- a/apps/opik-frontend/src/api/traces/useSpanCreateMutation.ts +++ /dev/null @@ -1,67 +0,0 @@ -import { useMutation, useQueryClient } from "@tanstack/react-query"; -import { AxiosError } from "axios"; -import get from "lodash/get"; - -import api, { SPANS_KEY, SPANS_REST_ENDPOINT } from "@/api/api"; -import { useToast } from "@/components/ui/use-toast"; -import { SPAN_TYPE } from "@/types/traces"; - -import { JsonNode, UsageType } from "@/types/shared"; -import { snakeCaseObj } from "@/lib/utils"; - -type UseSpanCreateMutationParams = { - id: string; - projectName?: string; - traceId: string; - parentSpanId?: string; - name: string; - type: SPAN_TYPE; - startTime: string; - endTime?: string; - input?: JsonNode; - output?: JsonNode; - model?: string; - provider?: string; - tags?: string[]; - usage?: UsageType; - metadata?: object; -}; - -const useSpanCreateMutation = () => { - const queryClient = useQueryClient(); - const { toast } = useToast(); - - return useMutation({ - mutationFn: async (span: UseSpanCreateMutationParams) => { - const { data } = await api.post(SPANS_REST_ENDPOINT, snakeCaseObj(span)); - - return data; - }, - onError: (error: AxiosError) => { - const message = get( - error, - ["response", "data", "message"], - error.message, - ); - - toast({ - title: "Error", - description: message, - variant: "destructive", - }); - }, - onSettled: (data, error, variables) => { - if (variables.projectName) { - queryClient.invalidateQueries({ - queryKey: ["projects"], - }); - } - - queryClient.invalidateQueries({ - queryKey: [SPANS_KEY], - }); - }, - }); -}; - -export default useSpanCreateMutation; diff --git a/apps/opik-frontend/src/api/traces/useTraceCreateMutation.ts b/apps/opik-frontend/src/api/traces/useTraceCreateMutation.ts deleted file mode 100644 index ca0db17535..0000000000 --- a/apps/opik-frontend/src/api/traces/useTraceCreateMutation.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { useMutation, useQueryClient } from "@tanstack/react-query"; -import { AxiosError } from "axios"; -import get from "lodash/get"; - -import api, { TRACES_KEY, TRACES_REST_ENDPOINT } from "@/api/api"; -import { useToast } from "@/components/ui/use-toast"; - -import { snakeCaseObj } from "@/lib/utils"; -import { JsonNode } from "@/types/shared"; - -type UseTraceCreateMutationParams = { - id: string; - projectName?: string; - name: string; - startTime: string; - endTime?: string; - input?: JsonNode; - output?: JsonNode; - tags?: string[]; - metadata?: object; -}; - -const useTraceCreateMutation = () => { - const queryClient = useQueryClient(); - const { toast } = useToast(); - - return useMutation({ - mutationFn: async (trace: UseTraceCreateMutationParams) => { - return api.post(TRACES_REST_ENDPOINT, snakeCaseObj(trace)); - }, - onError: (error: AxiosError) => { - const message = get( - error, - ["response", "data", "message"], - error.message, - ); - - toast({ - title: "Error", - description: message, - variant: "destructive", - }); - }, - onSettled: (data, error, variables) => { - if (variables.projectName) { - queryClient.invalidateQueries({ - queryKey: ["projects"], - }); - } - - queryClient.invalidateQueries({ - queryKey: [TRACES_KEY], - }); - }, - }); -}; - -export default useTraceCreateMutation; diff --git a/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/PlaygroundOutputActions.tsx b/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/PlaygroundOutputActions.tsx index 7801baa8ef..729b9396d6 100644 --- a/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/PlaygroundOutputActions.tsx +++ b/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/PlaygroundOutputActions.tsx @@ -59,9 +59,12 @@ const PlaygroundOutputActions = ({ })); }, [datasets]); + const datasetName = datasets?.find((ds) => ds.id === datasetId)?.name || null; + const { stopAll, runAll, isRunning } = useActionButtonActions({ workspaceName, datasetItems, + datasetName, }); const loadMoreHandler = useCallback(() => setIsLoadedMore(true), []); diff --git a/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/useActionButtonActions.ts b/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/useActionButtonActions.ts index f3fbc5ccce..7fcc8f94c4 100644 --- a/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/useActionButtonActions.ts +++ b/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/useActionButtonActions.ts @@ -1,70 +1,18 @@ -import { useCallback, useRef, useState } from "react"; +import { useCallback, useMemo, useRef, useState } from "react"; import asyncLib from "async"; -import mustache from "mustache"; -import isUndefined from "lodash/isUndefined"; - -import get from "lodash/get"; -import set from "lodash/set"; -import isObject from "lodash/isObject"; -import cloneDeep from "lodash/cloneDeep"; +import { useQueryClient } from "@tanstack/react-query"; import { DatasetItem } from "@/types/datasets"; -import { LLMMessage, ProviderMessageType } from "@/types/llm"; import { PlaygroundPromptType } from "@/types/playground"; -import { - usePromptIds, - usePromptMap, - useResetOutputMap, - useUpdateOutput, -} from "@/store/PlaygroundStore"; -import useCompletionProxyStreaming from "@/api/playground/useCompletionProxyStreaming"; -import useCreateOutputTraceAndSpan, { - CreateTraceSpanParams, -} from "@/api/playground/useCreateOutputTraceAndSpan"; -import { getPromptMustacheTags } from "@/lib/prompt"; - -const LIMIT_STREAMING_CALLS = 5; -const LIMIT_LOG_CALLS = 2; - -const serializeTags = (datasetItem: DatasetItem["data"], tags: string[]) => { - const newDatasetItem = cloneDeep(datasetItem); - - tags.forEach((tag) => { - const value = get(newDatasetItem, tag); - set(newDatasetItem, tag, isObject(value) ? JSON.stringify(value) : value); - }); - - return newDatasetItem; -}; - -export const transformMessageIntoProviderMessage = ( - message: LLMMessage, - datasetItem: DatasetItem["data"] = {}, -): ProviderMessageType => { - const messageTags = getPromptMustacheTags(message.content); - const serializedDatasetItem = serializeTags(datasetItem, messageTags); - - const notDefinedVariables = messageTags.filter((tag) => - isUndefined(get(serializedDatasetItem, tag)), - ); +import { usePromptIds, useResetOutputMap } from "@/store/PlaygroundStore"; - if (notDefinedVariables.length > 0) { - throw new Error(`${notDefinedVariables.join(", ")} not defined`); - } +import { useToast } from "@/components/ui/use-toast"; +import createLogPlaygroundProcessor, { + LogProcessorArgs, +} from "@/api/playground/createLogPlaygroundProcessor"; +import usePromptDatasetItemCombination from "@/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/usePromptDatasetItemCombination"; - return { - role: message.role, - content: mustache.render( - message.content, - serializedDatasetItem, - {}, - { - // avoid escaping of a mustache - escape: (val: string) => val, - }, - ), - }; -}; +const LIMIT_STREAMING_CALLS = 5; interface DatasetItemPromptCombination { datasetItem?: DatasetItem; @@ -74,146 +22,133 @@ interface DatasetItemPromptCombination { interface UseActionButtonActionsArguments { datasetItems: DatasetItem[]; workspaceName: string; + datasetName: string | null; } const useActionButtonActions = ({ datasetItems, workspaceName, + datasetName, }: UseActionButtonActionsArguments) => { - const [isRunning, setIsRunning] = useState(false); + const queryClient = useQueryClient(); - const isToStopRef = useRef(false); - const abortControllersOngoingRef = useRef(new Map()); + const { toast } = useToast(); - const promptMap = usePromptMap(); + const [isRunning, setIsRunning] = useState(false); + const [isToStop, setIsToStop] = useState(false); const promptIds = usePromptIds(); - const updateOutput = useUpdateOutput(); + const abortControllersRef = useRef(new Map()); + const resetOutputMap = useResetOutputMap(); - const createTraceSpan = useCreateOutputTraceAndSpan(); - const runStreaming = useCompletionProxyStreaming({ - workspaceName, - }); + const resetState = useCallback(() => { + resetOutputMap(); + abortControllersRef.current.clear(); + setIsRunning(false); + }, [resetOutputMap]); const stopAll = useCallback(() => { // nothing to stop - if (abortControllersOngoingRef.current.size === 0) { + if (abortControllersRef.current.size === 0) { return; } - isToStopRef.current = true; - abortControllersOngoingRef.current.forEach((controller) => - controller.abort(), - ); + setIsToStop(true); + abortControllersRef.current.forEach((controller) => controller.abort()); - abortControllersOngoingRef.current.clear(); + abortControllersRef.current.clear(); }, []); - const runAll = useCallback(async () => { - resetOutputMap(); - setIsRunning(true); - - const asyncLogQueue = asyncLib.queue( - createTraceSpan, - LIMIT_LOG_CALLS, - ); - - let combinations: DatasetItemPromptCombination[] = []; - - if (datasetItems.length > 0) { - combinations = datasetItems.flatMap((di) => - promptIds.map((promptId) => ({ - datasetItem: di, - prompt: promptMap[promptId], - })), - ); - } else if (promptIds.length > 0) { - combinations = promptIds.map((promptId) => ({ - prompt: promptMap[promptId], - })); - } - - const processCombination = async ({ - datasetItem, - prompt, - }: DatasetItemPromptCombination) => { - if (isToStopRef.current) { - return; - } - - const controller = new AbortController(); - - const datasetItemId = datasetItem?.id || ""; - const datasetItemData = datasetItem?.data || {}; - - const key = datasetItemId ? `${datasetItemId}-${prompt.id}` : prompt.id; - abortControllersOngoingRef.current.set(key, controller); + const showMessageExperimentsLogged = useCallback( + (experimentCount: number) => { + const title = + experimentCount === 1 ? "Experiment started" : "Experiments started"; + + const description = + experimentCount === 1 + ? "The experiment started successfully" + : `${experimentCount} experiments started successfully`; + + toast({ + title, + description, + }); + }, + [toast], + ); - try { - updateOutput(prompt.id, datasetItemId, { - isLoading: true, + const logProcessorHandlers: LogProcessorArgs = useMemo(() => { + return { + onAddExperimentRegistry: (experiments) => { + // to check if all experiments have been created + if (experiments.length === promptIds.length) { + showMessageExperimentsLogged(experiments.length); + queryClient.invalidateQueries({ + queryKey: ["experiments"], + }); + } + }, + onError: (e) => { + toast({ + title: "Error", + variant: "destructive", + description: e.message, }); - - const providerMessages = prompt.messages.map((m) => - transformMessageIntoProviderMessage(m, datasetItemData), - ); - - const run = await runStreaming({ - model: prompt.model, - messages: providerMessages, - configs: prompt.configs, - signal: controller.signal, - onAddChunk: (o) => { - updateOutput(prompt.id, datasetItemId, { - value: o, - }); - }, + }, + onCreateTraces: () => { + queryClient.invalidateQueries({ + queryKey: [["projects"]], }); + }, + }; + }, [queryClient, promptIds.length, showMessageExperimentsLogged, toast]); - const error = run.opikError || run.providerError; + const addAbortController = useCallback( + (key: string, value: AbortController) => { + abortControllersRef.current.set(key, value); + }, + [], + ); - updateOutput(prompt.id, datasetItemId, { - isLoading: false, - }); + const deleteAbortController = useCallback( + (key: string) => abortControllersRef.current.delete(key), + [], + ); - asyncLogQueue.push({ - ...run, - providerMessages, - configs: prompt.configs, - model: prompt.model, - }); + const { createCombinations, processCombination } = + usePromptDatasetItemCombination({ + workspaceName, + isToStop, + datasetItems, + datasetName, + addAbortController, + deleteAbortController, + }); - if (error) { - throw new Error(error); - } - } catch (error) { - const typedError = error as Error; + const runAll = useCallback(async () => { + resetState(); + setIsRunning(true); - updateOutput(prompt.id, datasetItemId, { - value: typedError.message, - isLoading: false, - }); - } - }; + const logProcessor = createLogPlaygroundProcessor(logProcessorHandlers); + + const combinations = createCombinations(); asyncLib.mapLimit( combinations, LIMIT_STREAMING_CALLS, - processCombination, + async (combination: DatasetItemPromptCombination) => + processCombination(combination, logProcessor), () => { setIsRunning(false); - isToStopRef.current = false; - abortControllersOngoingRef.current.clear(); + setIsToStop(false); + abortControllersRef.current.clear(); }, ); }, [ - resetOutputMap, - promptIds, - datasetItems, - promptMap, - createTraceSpan, - runStreaming, - updateOutput, + resetState, + createCombinations, + processCombination, + logProcessorHandlers, ]); return { diff --git a/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/usePromptDatasetItemCombination.ts b/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/usePromptDatasetItemCombination.ts new file mode 100644 index 0000000000..75197e8bca --- /dev/null +++ b/apps/opik-frontend/src/components/pages/PlaygroundPage/PlaygroundOutputs/PlaygroundOutputActions/usePromptDatasetItemCombination.ts @@ -0,0 +1,199 @@ +import { useCallback, useEffect, useRef } from "react"; +import { LogProcessor } from "@/api/playground/createLogPlaygroundProcessor"; +import { DatasetItem } from "@/types/datasets"; +import { PlaygroundPromptType } from "@/types/playground"; +import { + usePromptIds, + usePromptMap, + useUpdateOutput, +} from "@/store/PlaygroundStore"; +import useCompletionProxyStreaming from "@/api/playground/useCompletionProxyStreaming"; +import { LLMMessage, ProviderMessageType } from "@/types/llm"; +import { getPromptMustacheTags } from "@/lib/prompt"; +import isUndefined from "lodash/isUndefined"; +import get from "lodash/get"; +import mustache from "mustache"; +import cloneDeep from "lodash/cloneDeep"; +import set from "lodash/set"; +import isObject from "lodash/isObject"; + +export interface DatasetItemPromptCombination { + datasetItem?: DatasetItem; + prompt: PlaygroundPromptType; +} + +const serializeTags = (datasetItem: DatasetItem["data"], tags: string[]) => { + const newDatasetItem = cloneDeep(datasetItem); + + tags.forEach((tag) => { + const value = get(newDatasetItem, tag); + set(newDatasetItem, tag, isObject(value) ? JSON.stringify(value) : value); + }); + + return newDatasetItem; +}; + +const transformMessageIntoProviderMessage = ( + message: LLMMessage, + datasetItem: DatasetItem["data"] = {}, +): ProviderMessageType => { + const messageTags = getPromptMustacheTags(message.content); + const serializedDatasetItem = serializeTags(datasetItem, messageTags); + + const notDefinedVariables = messageTags.filter((tag) => + isUndefined(get(serializedDatasetItem, tag)), + ); + + if (notDefinedVariables.length > 0) { + throw new Error(`${notDefinedVariables.join(", ")} not defined`); + } + + return { + role: message.role, + content: mustache.render( + message.content, + serializedDatasetItem, + {}, + { + // avoid escaping of a mustache + escape: (val: string) => val, + }, + ), + }; +}; + +interface UsePromptDatasetItemCombinationArgs { + datasetItems: DatasetItem[]; + isToStop: boolean; + workspaceName: string; + datasetName: string | null; + addAbortController: (key: string, value: AbortController) => void; + deleteAbortController: (key: string) => void; +} + +const usePromptDatasetItemCombination = ({ + datasetItems, + isToStop, + workspaceName, + datasetName, + addAbortController, + deleteAbortController, +}: UsePromptDatasetItemCombinationArgs) => { + const updateOutput = useUpdateOutput(); + + // the reason why we need ref here is that the value is taken in a deep callback + // the prop is just taken as the value on the moment of creation + const isToStopRef = useRef(isToStop); + + const runStreaming = useCompletionProxyStreaming({ + workspaceName, + }); + + useEffect(() => { + isToStopRef.current = isToStop; + }, [isToStop]); + + const promptIds = usePromptIds(); + const promptMap = usePromptMap(); + + const createCombinations = useCallback((): DatasetItemPromptCombination[] => { + if (datasetItems.length > 0 && promptIds.length > 0) { + return datasetItems.flatMap((di) => + promptIds.map((promptId) => ({ + datasetItem: di, + prompt: promptMap[promptId], + })), + ); + } + + return promptIds.map((promptId) => ({ + prompt: promptMap[promptId], + })); + }, [datasetItems, promptMap, promptIds]); + + const processCombination = useCallback( + async ( + { datasetItem, prompt }: DatasetItemPromptCombination, + logProcessor: LogProcessor, + ) => { + if (isToStopRef.current) { + return; + } + + const controller = new AbortController(); + + const datasetItemId = datasetItem?.id || ""; + const datasetItemData = datasetItem?.data || {}; + const key = `${datasetItemId}-${prompt.id}`; + + addAbortController(key, controller); + + try { + updateOutput(prompt.id, datasetItemId, { + isLoading: true, + }); + + const providerMessages = prompt.messages.map((m) => + transformMessageIntoProviderMessage(m, datasetItemData), + ); + + const run = await runStreaming({ + model: prompt.model, + messages: providerMessages, + configs: prompt.configs, + signal: controller.signal, + onAddChunk: (o) => { + updateOutput(prompt.id, datasetItemId, { + value: o, + }); + }, + }); + + const error = run.opikError || run.providerError; + + updateOutput(prompt.id, datasetItemId, { + isLoading: false, + }); + + logProcessor.log({ + ...run, + providerMessages, + configs: prompt.configs, + model: prompt.model, + + promptId: prompt.id, + datasetName, + datasetItemId: datasetItemId, + }); + + if (error) { + throw new Error(error); + } + } catch (error) { + const typedError = error as Error; + + updateOutput(prompt.id, datasetItemId, { + value: typedError.message, + isLoading: false, + }); + } finally { + deleteAbortController(key); + } + }, + + [ + datasetName, + runStreaming, + updateOutput, + addAbortController, + deleteAbortController, + ], + ); + + return { + createCombinations, + processCombination, + }; +}; + +export default usePromptDatasetItemCombination; diff --git a/apps/opik-frontend/src/lib/batches.ts b/apps/opik-frontend/src/lib/batches.ts new file mode 100644 index 0000000000..fb60287638 --- /dev/null +++ b/apps/opik-frontend/src/lib/batches.ts @@ -0,0 +1,54 @@ +export const createBatchProcessor = ( + processCallback: (items: T[]) => void, + maxBatchSize = 20, + flushInterval = 2000, +) => { + let currentBatch: T[] = []; + let flushTimer: NodeJS.Timeout | null = null; + + const processCurrentBatch = () => { + if (currentBatch.length > 0) { + processCallback(currentBatch); + currentBatch = []; + } + }; + + const flushBatch = () => { + processCurrentBatch(); + + if (flushTimer) { + clearTimeout(flushTimer); + flushTimer = null; + } + }; + + const startFlushTimer = () => { + if (flushTimer) { + clearTimeout(flushTimer); + } + + flushTimer = setTimeout(() => { + flushBatch(); + }, flushInterval); + }; + + const addItemToBatch = (item: T) => { + currentBatch.push(item); + + if (currentBatch.length >= maxBatchSize) { + processCurrentBatch(); + + if (flushTimer) { + clearTimeout(flushTimer); + flushTimer = null; + } + } else { + startFlushTimer(); + } + }; + + return { + addItem: addItemToBatch, + flush: flushBatch, + }; +}; diff --git a/apps/opik-frontend/src/types/playground.ts b/apps/opik-frontend/src/types/playground.ts index 25308ad6e2..d9548e17c2 100644 --- a/apps/opik-frontend/src/types/playground.ts +++ b/apps/opik-frontend/src/types/playground.ts @@ -1,7 +1,9 @@ -import { UsageType } from "@/types/shared"; -import { LLMMessage } from "@/types/llm"; import { HttpStatusCode } from "axios"; + +import { JsonNode, UsageType } from "@/types/shared"; +import { LLMMessage, ProviderMessageType } from "@/types/llm"; import { LLMPromptConfigsType, PROVIDER_MODEL_TYPE } from "@/types/providers"; +import { SPAN_TYPE } from "@/types/traces"; export interface PlaygroundPromptType { name: string; @@ -40,3 +42,48 @@ export type ChatCompletionResponse = | ChatCompletionOpikErrorMessageType | ChatCompletionSuccessMessageType | ChatCompletionProviderErrorMessageType; + +export interface LogTrace { + id: string; + projectName: string; + name: string; + startTime: string; + endTime: string; + input: { messages: ProviderMessageType[] }; + output: { output: string | null }; +} + +export interface LogSpan { + id: string; + traceId: string; + projectName: string; + type: SPAN_TYPE.llm; + name: string; + startTime: string; + endTime: string; + input: { messages: ProviderMessageType[] }; + output: { choices: ChatCompletionMessageChoiceType[] }; + usage?: UsageType | null; + metadata: { + created_from: string; + usage: UsageType | null; + model: string; + parameters: LLMPromptConfigsType; + }; +} + +export interface LogExperiment { + id: string; + datasetName: string; + name?: string; + metadata?: object; +} + +export type LogExperimentItem = { + id: string; + experimentId: string; + datasetItemId: string; + traceId: string; +} & { + [inputOutputField: string]: JsonNode; +};